diff --git a/CODEOWNERS b/CODEOWNERS index 007a304c3e706ce968576ec8979c08f1a3bcc552..b9f0313cc6d59d3fbdcd014e1a528126d863075a 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -45,7 +45,7 @@ # /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh # /tensorflow/contrib/slim/ @sguada @thenbasilmanran # /tensorflow/contrib/stateless/ @girving -# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst +# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank # /tensorflow/contrib/testing/ @dandelionmane # /tensorflow/contrib/timeseries/ @allenlavoie # /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu diff --git a/README.md b/README.md index 916e5200b29841028652c861c49dbb3650baea3c..e1a50c87e26d493ba3ac760f357905d89aa40dab 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,17 @@ ----------------- -| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------| -| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | + +| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | +|-----------------|---------------------|------------------|-------------------|---------------|---------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture lets you deploy computation to one +between them. This flexible architecture enables you to deploy computation to one or more CPUs or GPUs in a desktop, server, or mobile device without rewriting -code. TensorFlow also includes TensorBoard, a data visualization toolkit. +code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), a data visualization toolkit. TensorFlow was originally developed by researchers and engineers working on the Google Brain team within Google's Machine Intelligence Research @@ -21,19 +22,9 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -**If you want to contribute to TensorFlow, be sure to review the [contribution -guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's -[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to -uphold this code.** - -**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions -and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** - -The TensorFlow project strives to abide by generally accepted best practices in open-source software development: - -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) +Keep up to date with release announcements and security updates by +subscribing to +[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -75,10 +66,27 @@ $ python >>> sess.close() ``` +## Contribution guidelines + +**If you want to contribute to TensorFlow, be sure to review the [contribution +guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's +[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to +uphold this code.** + +**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for +tracking requests and bugs. So please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions +and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** + +The TensorFlow project strives to abide by generally accepted best practices in open-source software development: + +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + ## For more information * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) +* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) diff --git a/RELEASE.md b/RELEASE.md index 0720a8c639f8ab87214b11f6a8092b432b916853..2717c75740aeea7821fb6c57dfc85908e86e9d51 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,132 @@ +# Release 1.8.0 + +## Major Features And Improvements +* Can now pass `tf.contrib.distribute.MirroredStrategy()` to `tf.estimator.RunConfig()` to run an Estimator model on multiple GPUs on one machine. +* Add `tf.contrib.data.prefetch_to_device()`, which supports prefetching to GPU memory. +* Added Gradient Boosted Trees as pre-made Estimators: BoostedTreesClassifier, BoostedTreesRegressor. +* Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability. +* `tf.contrib.bayesflow` is moving out to it's own repo. +* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication. + +## Bug Fixes and Other Changes +* `tf.data`: + * Add `tf.contrib.data.prefetch_to_device`, which enables prefetching dataset elements to GPU memory. + * Add `tf.contrib.data.AUTOTUNE`, which allows the tf.data runtime to automatically tune the prefetch buffer sizes based on your system and environment. + * Add `tf.contrib.data.make_csv_dataset` for building datasets of CSV files. +* Eager Execution: + * With eager execution Datasets can now be used as standard python iterators (`for batch in dataset:`). Both `Dataset.__iter__()` and `Dataset.make_one_shot_iterator()` can now be used to create iterators when eager execution is enabled. + * Automatic device placement has been enabled (i.e., use a GPU if available automatically, without requiring an explicit `with tf.device(“/gpu:0”)`) (Fixes #14133) + * `tf.GradientTape` has moved out of contrib. +* `tf.keras`: + * Added the fashion mnist dataset. + * New data preprocessing functions: `image/random_brightness`, `sequence/TimeseriesGenerator`, and `text/hashing_trick`. +* Accelerated Linear Algebra (XLA): + * Select and scatter in reference util and evaluator now use lexicographical order to break ties. +* TensorFlow Debugger (tfdbg) CLI: + * During tensor-filter operations, allow exclusion of nodes by regular expressions. + * Fix spurious background colors in some text terminals. +* `tf.contrib`: + * Add meta-distribution BatchReshape which reshapes batch dimensions. + * `tf.contrib.layers.recompute_grad` works for explicit gradient checkpointing on TPU. + * Add `tf.contrib.framework.argsort`. + * Allow `DNNBoostedTreeCombinedEstimator` to work with core versions of feature columns and losses. + * Add non-linear image warping ops: `tf.contrib.image.sparse_image_warp`, `tf.contrib.image.dense_image_warp`, and `tf.contrib.image.interpolate_spline`. + * Fix bug in `tf.contrib.opt.MultitaskOptimizerWrapper` where types of tensors were mismatched. +* Other: + * Low-level graph construction now calls the TensorFlow C API. This change should be invisible to most users, but can be disabled by setting the environment variable `TF_C_API_GRAPH_CONSTRUCTION=0` in this release. Future releases will remove the ability to disable this change. Please [file a bug](https://github.com/tensorflow/tensorflow/issues/new) if you find yourself using this escape hatch. + * Add description of shapes and a pointer to tutorial notebook in `tf.distributions.Distribution`. + * Update scatter operations: + * Add `tf.scatter_min` and `tf.scatter_max` + * Extend scatter operations to work with a scalar update parameter. + * Move cuDNN RNN ops to core for use in TensorFlow codebase only. + * Add `float64` support for `Conv2d`, `Conv2dBackpropInput`, and `Conv2dBackpropFilter`. + * Add `float64` support for `AvgPool`/`AvgPoolGrad`. + * Make graph name scope thread local so that they work correctly in multi-threaded environments. + * Update nsync synchronization library to avoid slow primitives on Linux. + * Removed need to put nsync/public on C include path when building custom ops. + * Add `tf.image.psnr`, `tf.image.ssim`, `tf.image.ssim_multiscale`, `tf.image.image_gradients`, `tf.image.sobel_edges`. + * Add links to https://js.tensorflow.org. + * Fix non-uniformity of orthogonal matrices. + * Fix bug where multi-image Estimator eval summaries were not displayed correctly. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu + + +# Release 1.7.0 + +## Major Features And Improvements +* Eager mode is moving out of contrib, try `tf.enable_eager_execution()`. +* Graph rewrites emulating fixed-point quantization compatible with TensorFlow Lite, supported by new `tf.contrib.quantize` package. +* Easily customize gradient computation with `tf.custom_gradient`. +* [TensorBoard Debugger Plugin](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/debugger/README.md), the graphical user interface (GUI) of TensorFlow Debugger (tfdbg), is now in alpha. +* Experimental support for reading a sqlite database as a `Dataset` with new `tf.contrib.data.SqlDataset`. +* Distributed Mutex / CriticalSection added to `tf.contrib.framework.CriticalSection`. +* Better text processing with `tf.regex_replace`. +* Easy, efficient sequence input with `tf.contrib.data.bucket_by_sequence_length` +* Initial support for `tf.contrib.tensorrt` that enables native TensorRT in + TensorFlow. + +## Bug Fixes and Other Changes +* Accelerated Linear Algebra (XLA): + * Add `MaxPoolGradGrad` support for XLA + * CSE pass from Tensorflow is now disabled in XLA. +* `tf.data`: + * `tf.data.Dataset` + * Add support for building C++ Dataset op kernels as external libraries, using the `tf.load_op_library()` mechanism. + * `Dataset.list_files()` now shuffles its output by default. + * `Dataset.shuffle(..., seed=tf.constant(0, dtype=tf.int64))` now yields the same sequence of elements as `Dataset.shuffle(..., seed=0)`. + * Add `num_parallel_reads` argument to `tf.data.TFRecordDataset`. +* `tf.contrib`: + * `tf.contrib.bayesflow.halton_sequence` now supports randomization. + * Add support for scalars in `tf.contrib.all_reduce`. + * Add `effective_sample_size` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `potential_scale_reduction` to `tf.contrib.bayesflow.mcmc_diagnostics`. + * Add `BatchNormalization`, `Kumaraswamy` bijectors. + * Deprecate `tf.contrib.learn`. Please check contrib/learn/README.md for instructions on how to convert existing code. + * `tf.contrib.data` + * Remove deprecated `tf.contrib.data.Dataset`, `tf.contrib.data.Iterator`, `tf.contrib.data.FixedLengthRecordDataset`, `tf.contrib.data.TextLineDataset`, and `tf.contrib.data.TFRecordDataset` classes. + * Added `bucket_by_sequence_length`, `sliding_window_batch`, and `make_batched_features_dataset` + * Remove unmaintained `tf.contrib.ndlstm`. You can find it externally at https://github.com/tmbarchive/tfndlstm. + * Moved most of `tf.contrib.bayesflow` to its own repo: `tfp` +* Other: + * tf.py_func now reports the full stack trace if an exception occurs. + * Integrate `TPUClusterResolver` with GKE's integration for Cloud TPUs. + * Add a library for statistical testing of samplers. + * Add Helpers to stream data from the GCE VM to a Cloud TPU. + * Integrate ClusterResolvers with TPUEstimator. + * Unify metropolis_hastings interface with HMC kernel. + * Move LIBXSMM convolutions to a separate --define flag so that they are disabled by default. + * Fix `MomentumOptimizer` lambda. + * Reduce `tfp.layers` boilerplate via programmable docstrings. + * Add `auc_with_confidence_intervals`, a method for computing the AUC and confidence interval with linearithmic time complexity. + * `regression_head` now accepts customized link function, to satisfy the usage that user can define their own link function if the `array_ops.identity` does not meet the requirement. + * Fix `initialized_value` and `initial_value` behaviors for `ResourceVariables` created from `VariableDef` protos. + * Add TensorSpec to represent the specification of Tensors. + * Constant folding pass is now deterministic. + * Support `float16` `dtype` in `tf.linalg.*`. + * Add `tf.estimator.export.TensorServingInputReceiver` that allows `tf.estimator.Estimator.export_savedmodel` to pass raw tensors to model functions. + +## Deprecations + +* TensorFlow 1.7 may be the last time we support Cuda versions below 8.0. + Starting with TensorFlow 1.8 release, 8.0 will be the minimum supported + version. +* TensorFlow 1.7 may be the last time we support cuDNN versions below 6.0. + Starting with TensorFlow 1.8 release, 6.0 will be the minimum supported + version. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abe, Alistair Low, Andy Kernahan, Appledore, Ben, Ben Barsdell, Boris Pfahringer, Brad Wannow, Brett Koonce, Carl Thomé, cclauss, Chengzhi Chen, Chris Drake, Christopher Yeh, Clayne Robison, Codrut Grosu, Daniel Trebbien, Danny Goodman, David Goodwin, David Norman, Deron Eriksson, Donggeon Lim, Donny Viszneki, DosLin, DylanDmitri, Francisco Guerrero, Fred Reiss, gdh1995, Giuseppe, Glenn Weidner, gracehoney, Guozhong Zhuang, Haichen "Hc" Li, Harald Husum, harumitsu.nobuta, Henry Spivey, hsm207, Jekyll Song, Jerome, Jiongyan Zhang, jjsjann123, John Sungjin Park, Johnson145, JoshVarty, Julian Wolff, Jun Wang, June-One, Kamil Sindi, Kb Sriram, Kdavis-Mozilla, Kenji, lazypanda1, Liang-Chi Hsieh, Loo Rong Jie, Mahesh Bhosale, MandarJKulkarni, ManHyuk, Marcus Ong, Marshal Hayes, Martin Pool, matthieudelaro, mdfaijul, mholzel, Michael Zhou, Ming Li, Minmin Sun, Myungjoo Ham, MyungsungKwak, Naman Kamra, Peng Yu, Penghao Cen, Phil, Raghuraman-K, resec, Rohin Mohanadas, Sandeep N Gupta, Scott Tseng, seaotterman, Seo Sanghyeon, Sergei Lebedev, Ted Chang, terrytangyuan, Tim H, tkunic, Tod, vihanjain, Yan Facai (颜发才), Yin Li, Yong Tang, Yukun Chen, Yusuke Yamada + + + # Release 1.6.0 ## Breaking Changes @@ -21,7 +150,7 @@ newcomers. * Other: * Add `tf.contrib.distributions.Kumaraswamy`. * `RetryingFileSystem::FlushCaches()` calls the base FileSystem's `FlushCaches()`. - * Add auto_correlation to distributions. + * Add `auto_correlation` to distributions. * Add `tf.contrib.distributions.Autoregressive`. * Add SeparableConv1D layer. * Add convolutional Flipout layers. @@ -31,12 +160,12 @@ newcomers. * Output variance over trees predictions for classifications tasks. * For `pt` and `eval` commands, allow writing tensor values to filesystem as numpy files. * gRPC: Propagate truncated errors (instead of returning gRPC internal error). - * Augment parallel_interleave to support 2 kinds of prefetching. + * Augment `parallel_interleave` to support 2 kinds of prefetching. * Improved XLA support for C64-related ops log, pow, atan2, tanh. * Add probabilistic convolutional layers. ## API Changes -* Introducing prepare_variance boolean with default setting to False for backward compatibility. +* Introducing `prepare_variance` boolean with default setting to False for backward compatibility. * Move `layers_dense_variational_impl.py` to `layers_dense_variational.py`. ## Known Bugs @@ -96,27 +225,6 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 * Starting from 1.6 release, our prebuilt binaries will use AVX instructions. This may break TF on older CPUs. -## Known Bugs -* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or - `CUDA_ILLEGAL_ADDRESS` failures. - - Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 - and CUDA 9.1 sometimes does not properly compute the carry bit when - decomposing 64-bit address calculations with large offsets (e.g. `load [x + - large_constant]`) into 32-bit arithmetic in SASS. - - As a result, these versions of `ptxas` miscompile most XLA programs which use - more than 4GB of temp memory. This results in garbage results and/or - `CUDA_ERROR_ILLEGAL_ADDRESS` failures. - - A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a - fix for CUDA 9.0.x. Until the fix is available, the only workaround is to - [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x - or disable XLA:GPU. - - TensorFlow will print a warning if you use XLA:GPU with a known-bad version of - CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. - ## Major Features And Improvements * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) preview version is now available. diff --git a/tensorflow/SECURITY.md b/SECURITY.md similarity index 90% rename from tensorflow/SECURITY.md rename to SECURITY.md index 6ddac1f964dfba3afd240441e2a036bc24ee6d91..a5ce3a62ee202f6e7d83f0fedc2777d9c88ba9b5 100644 --- a/tensorflow/SECURITY.md +++ b/SECURITY.md @@ -6,7 +6,7 @@ report vulnerabilities in TensorFlow. ## TensorFlow models are programs -TensorFlow's runtime system interprets and executes programs. What machine +TensorFlow's runtime system interprets and executes programs. What machine learning practitioners term [**models**](https://developers.google.com/machine-learning/glossary/#model) are expressed as programs that TensorFlow executes. TensorFlow programs are encoded @@ -28,12 +28,12 @@ data you supply to TensorFlow to train a model, or to use a model to run inference on the data. **TensorFlow models are programs, and need to be treated as such from a security -perspective.** +perspective.** ## Running untrusted models As a general rule: **Always** execute untrusted models inside a sandbox (e.g., -[nsjail](https://github.com/google/nsjail)). +[nsjail](https://github.com/google/nsjail)). There are several ways in which a model could become untrusted. Obviously, if an untrusted party supplies TensorFlow kernels, arbitrary code may be executed. @@ -109,11 +109,11 @@ graphs known to the `ModelServer`. This means that an attacker may run graphs using untrusted inputs as described above, but they would not be able to execute arbitrary graphs. It is possible to safely expose a `ModelServer` directly to an untrusted network, **but only if the graphs it is configured to -use have been carefully audited to be safe**. +use have been carefully audited to be safe**. Similar to best practices for other servers, we recommend running any `ModelServer` with appropriate privileges (i.e., using a separate user with -reduced permisisons). In the spirit of defense in depth, we recommend +reduced permissions). In the spirit of defense in depth, we recommend authenticating requests to any TensorFlow server connected to an untrusted network, as well as sandboxing the server to minimize the adverse effects of any breach. @@ -129,11 +129,11 @@ with specially crafted inputs. ### What is a vulnerability? Given TensorFlow's flexibility, it is possible to specify computation graphs -which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models +which exhibit unexpected or unwanted behavior. The fact that TensorFlow models can perform arbitrary computations means that they may read and write files, communicate via the network, produce deadlocks and infinite loops, or run out of memory. It is only when these behaviors are outside the specifications of the -operations involved that such behavior is a vulnerability. +operations involved that such behavior is a vulnerability. A `FileWriter` writing a file is not unexpected behavior and therefore is not a vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution @@ -170,6 +170,17 @@ Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of the progress being made towards a fix and announcement. +In addition, please include the following information along with your report: + +* Your name and affiliation (if any). +* A description the technical details of the vulnerabilities. It is very + important to let us know how we can reproduce your findings. +* An explanation who can exploit this vulnerability, and what they gain when + doing so -- write an attack scenario. This will help us evaluate your report + quickly, especially if the issue is complex. +* Whether this vulnerability public or known to third parties. If it is, please + provide details. + If you believe that an existing (public) issue is security-related, please send an email to `security@tensorflow.org`. The email should include the issue ID and a short description of why it should be handled according to this security @@ -233,7 +244,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known vulnerabilities -| Type | Versions affected | Reported by | Additional Information | -|------|:-----------------:|---------------------------------------| -| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +| Type | Versions affected | Reported by | Additional Information | +|--------------------|:-----------------:|-----------------------|-----------------------------| +| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/WORKSPACE b/WORKSPACE index 1e38a9a8cd754886fc5232531816b875de0879a3..4ddfb9a3832ea1ea639ace887e1d601bdd857086 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657", - strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f", + sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", + strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 ], ) @@ -14,6 +14,12 @@ load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") closure_repositories() +# We must check the bazel version before trying to parse any other BUILD +# files, in case the parsing of those build files depends on the bazel +# version we require here. +load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") +check_bazel_version_at_least("0.10.0") + load("//tensorflow:workspace.bzl", "tf_workspace") # Uncomment and update the paths in these entries to build the Android demo. diff --git a/configure b/configure index 9c21d2b03a27714f05094667691e74c16fa89f35..66b66ba54ed68a9aa0ee556f84f68c3a83a495ab 100755 --- a/configure +++ b/configure @@ -8,7 +8,8 @@ if [ -z "$PYTHON_BIN_PATH" ]; then fi # Set all env variables -"$PYTHON_BIN_PATH" configure.py +CONFIGURE_DIR=$(dirname "$0") +"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" echo "Configuration finished" diff --git a/configure.py b/configure.py index 3aa1a3e956c6a559b89cdeb593a96a95188c32ae..b745e374a2baaffec73f9f9382e1bab322e7f0fd 100644 --- a/configure.py +++ b/configure.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import errno import os import platform @@ -32,18 +33,15 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '.tf_configure.bazelrc') -_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'WORKSPACE') _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' +_DEFAULT_NCCL_VERSION = '1.3' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) -_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' @@ -51,6 +49,11 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 +_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) +_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' +_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) +_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') + class UserInputError(Exception): pass @@ -119,22 +122,6 @@ def sed_in_place(filename, old, new): f.write(newdata) -def remove_line_with(filename, token): - """Remove lines that contain token from file. - - Args: - filename: string for filename. - token: string token to check if to remove a line from file or not. - """ - with open(filename, 'r') as f: - filedata = f.read() - - with open(filename, 'w') as f: - for line in filedata.strip().split('\n'): - if token not in line: - f.write(line + '\n') - - def write_to_bazelrc(line): with open(_TF_BAZELRC, 'a') as f: f.write(line + '\n') @@ -239,31 +226,34 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --force_python=py%s' % python_major_version) - write_to_bazelrc('build --host_force_python=py%s' % python_major_version) write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open('tools/python_bin_path.sh', 'w') as f: + with open(os.path.join( + _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(): +def reset_tf_configure_bazelrc(workspace_path): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - - home = os.path.expanduser('~') - if not os.path.exists('.bazelrc'): - if os.path.exists(os.path.join(home, '.bazelrc')): - with open('.bazelrc', 'a') as f: - f.write('import %s/.bazelrc\n' % home.replace('\\', '/')) + bazelrc_path = os.path.join(workspace_path, '.bazelrc') + + data = [] + if os.path.exists(bazelrc_path): + with open(bazelrc_path, 'r') as f: + data = f.read().splitlines() + with open(bazelrc_path, 'w') as f: + for l in data: + if _TF_BAZELRC_FILENAME in l: + continue + f.write('%s\n' % l) + if is_windows(): + tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") else: - open('.bazelrc', 'w').close() - - remove_line_with('.bazelrc', 'tf_configure') - with open('.bazelrc', 'a') as f: - f.write('import %workspace%/.tf_configure.bazelrc\n') + tf_bazelrc_path = _TF_BAZELRC + f.write('import %s\n' % tf_bazelrc_path) def cleanup_makefile(): @@ -271,7 +261,8 @@ def cleanup_makefile(): These files could interfere with Bazel parsing. """ - makefile_download_dir = 'tensorflow/contrib/makefile/downloads' + makefile_download_dir = os.path.join( + _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -456,7 +447,7 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', 'version']) + curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -492,6 +483,8 @@ def set_cc_opt_flags(environ_cp): if is_ppc64le(): # gcc on ppc64le does not support -march, use mcpu instead default_cc_opt_flags = '-mcpu=native' + elif is_windows(): + default_cc_opt_flags = '/arch:AVX' else: default_cc_opt_flags = '-march=native' question = ('Please specify optimization flags to use during compilation when' @@ -502,14 +495,14 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - write_to_bazelrc('build:opt --host_copt=-march=native') + if not is_ppc64le() and not is_windows(): + write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') # TODO(mikecase): Remove these default defines once we are able to get # TF Lite targets building without them. write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -531,7 +524,7 @@ def set_tf_cuda_clang(environ_cp): def set_tf_download_clang(environ_cp): """Set TF_DOWNLOAD_CLANG action_env.""" - question = 'Do you want to download a fresh release of clang? (Experimental)' + question = 'Do you wish to download a fresh release of clang? (Experimental)' yes_reply = 'Clang will be downloaded and used to compile tensorflow.' no_reply = 'Clang will not be downloaded.' set_action_env_var( @@ -916,7 +909,7 @@ def set_tf_cudnn_version(environ_cp): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) - tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version) ,1) + tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1) default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' @@ -1055,7 +1048,10 @@ def set_tf_tensorrt_install_path(environ_cp): for lib_file in possible_files: if is_compatible(lib_file, cuda_ver, cudnn_ver): - ver_str = nvinfer_pattern.search(lib_file).group(1) + matches = nvinfer_pattern.search(lib_file) + if len(matches.groups()) == 0: + continue + ver_str = matches.group(1) ver = convert_version_to_int(ver_str) if len(ver_str) else 0 if ver > highest_ver[0]: highest_ver = [ver, ver_str, lib_file] @@ -1078,7 +1074,7 @@ def set_tf_tensorrt_install_path(environ_cp): break # Reset and Retry - if len(possible_files): + if possible_files: print('TensorRT libraries found in one the following directories', 'are not compatible with selected cuda and cudnn installations') print(trt_install_path) @@ -1087,7 +1083,8 @@ def set_tf_tensorrt_install_path(environ_cp): if search_result: print(libnvinfer_path_from_ldconfig) else: - print('Invalid path to TensorRT. None of the following files can be found:') + print( + 'Invalid path to TensorRT. None of the following files can be found:') print(trt_install_path) print(os.path.join(trt_install_path, 'lib')) print(os.path.join(trt_install_path, 'lib64')) @@ -1106,6 +1103,81 @@ def set_tf_tensorrt_install_path(environ_cp): write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) +def set_tf_nccl_install_path(environ_cp): + """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION. + + Args: + environ_cp: copy of the os.environ. + + Raises: + ValueError: if this method was called under non-Linux platform. + UserInputError: if user has provided invalid input multiple times. + """ + if not is_linux(): + raise ValueError('Currently NCCL is only supported on Linux platforms.') + + ask_nccl_version = ( + 'Please specify the NCCL version you want to use. ' + '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION + + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + tf_nccl_version = get_from_env_or_user_or_default( + environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION) + tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1) + + if tf_nccl_version == '1': + break # No need to get install path, NCCL 1 is a GitHub repo. + + # TODO(csigg): Look with ldconfig first if we can find the library in paths + # like /usr/lib/x86_64-linux-gnu and the header file in the corresponding + # include directory. This is where the NCCL .deb packages install them. + # Then ask the user if we should use that. Instead of a single + # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to + # nccl_configure.bzl + default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH') + ask_nccl_path = (r'Please specify the location where NCCL %s library is ' + 'installed. Refer to README.md for more details. [Default ' + 'is %s]:') % (tf_nccl_version, default_nccl_path) + nccl_install_path = get_from_env_or_user_or_default( + environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path)) + if is_windows() or is_cygwin(): + nccl_install_path = cygpath(nccl_install_path) + + if is_windows(): + nccl_lib_path = 'lib/x64/nccl.lib' + elif is_linux(): + nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version + elif is_macos(): + nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version + + nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) + nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') + if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + # Set NCCL_INSTALL_PATH + environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path + write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) + break + + # Reset and Retry + print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the ' + 'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path, + nccl_hdr_path)) + + environ_cp['TF_NCCL_VERSION'] = '' + else: + raise UserInputError('Invalid TF_NCCL setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + + # Set TF_NCCL_VERSION + environ_cp['TF_NCCL_VERSION'] = tf_nccl_version + write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version) + + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -1228,7 +1300,7 @@ def set_host_c_compiler(environ_cp): environ_cp, var_name='HOST_C_COMPILER', var_default=default_c_host_compiler, - ask_for_var=('Please specify which C compiler should be used as the host' + ask_for_var=('Please specify which C compiler should be used as the host ' 'C compiler.'), check_success=os.path.exists, error_msg='Invalid C compiler path. %s cannot be found.', @@ -1372,13 +1444,20 @@ def config_info_line(name, help_text): def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", + type=str, + default=_TF_WORKSPACE_ROOT, + help="The absolute path to your active Bazel workspace.") + args = parser.parse_args() + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. environ_cp = dict(os.environ) - check_bazel_version('0.5.4') + check_bazel_version('0.10.0') - reset_tf_configure_bazelrc() + reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() setup_python(environ_cp) @@ -1393,6 +1472,9 @@ def main(): environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' environ_cp['TF_NEED_TENSORRT'] = '0' + # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on + # Windows. + environ_cp['TF_DOWNLOAD_CLANG'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1407,7 +1489,7 @@ def main(): set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', - 'with_kafka_support', False, 'kafka') + 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1432,22 +1514,18 @@ def main(): set_tf_cudnn_version(environ_cp) if is_linux(): set_tf_tensorrt_install_path(environ_cp) + set_tf_nccl_install_path(environ_cp) + set_tf_cuda_compute_capabilities(environ_cp) - if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get('LD_LIBRARY_PATH') != '1': - write_action_env_to_bazelrc('LD_LIBRARY_PATH', environ_cp.get('LD_LIBRARY_PATH')) + if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( + 'LD_LIBRARY_PATH') != '1': + write_action_env_to_bazelrc('LD_LIBRARY_PATH', + environ_cp.get('LD_LIBRARY_PATH')) set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': - if not is_windows(): - # Ask if we want to download clang release while building. - set_tf_download_clang(environ_cp) - else: - # We use bazel's generated crosstool on Windows and there is no - # way to provide downloaded toolchain for that yet. - # TODO(ibiryukov): Investigate using clang as a cuda compiler on - # Windows. - environ_cp['TF_DOWNLOAD_CLANG'] = '0' - + # Ask whether we should download the clang toolchain. + set_tf_download_clang(environ_cp) if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) @@ -1457,6 +1535,13 @@ def main(): if not is_windows(): set_gcc_host_compiler_path(environ_cp) set_other_cuda_vars(environ_cp) + else: + # CUDA not required. Ask whether we should download the clang toolchain and + # use it for the CPU build. + set_tf_download_clang(environ_cp) + if environ_cp.get('TF_DOWNLOAD_CLANG') == '1': + write_to_bazelrc('build --config=download_clang') + write_to_bazelrc('test --config=download_clang') set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index dc995d231d3e591771f801e28024a76610cdba26..f2ad16fa04f5beb6616c58c28d0f0c460c3e3a17 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -240,6 +240,13 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_kafka_support_windows_override", + define_values = {"with_kafka_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_gcp_support_android_override", define_values = {"with_gcp_support": "true"}, @@ -394,309 +401,6 @@ package_group( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -py_library( - name = "tensorflow_py", - srcs = ["__init__.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = ["//tensorflow/python"], -) - -filegroup( - name = "all_opensource_files", - data = [ - ":all_files", - "//tensorflow/c:all_files", - "//tensorflow/cc:all_files", - "//tensorflow/cc/saved_model:all_files", - "//tensorflow/cc/saved_model/python:all_files", - "//tensorflow/cc/tools:all_files", - "//tensorflow/compiler/aot:all_files", - "//tensorflow/compiler/aot/tests:all_files", - "//tensorflow/compiler/jit:all_files", - "//tensorflow/compiler/jit/graphcycles:all_files", - "//tensorflow/compiler/jit/kernels:all_files", - "//tensorflow/compiler/jit/legacy_flags:all_files", - "//tensorflow/compiler/jit/ops:all_files", - "//tensorflow/compiler/plugin:all_files", - "//tensorflow/compiler/tests:all_files", - "//tensorflow/compiler/tf2xla:all_files", - "//tensorflow/compiler/tf2xla/cc:all_files", - "//tensorflow/compiler/tf2xla/kernels:all_files", - "//tensorflow/compiler/tf2xla/lib:all_files", - "//tensorflow/compiler/tf2xla/ops:all_files", - "//tensorflow/compiler/xla:all_files", - "//tensorflow/compiler/xla/client:all_files", - "//tensorflow/compiler/xla/client/lib:all_files", - "//tensorflow/compiler/xla/legacy_flags:all_files", - "//tensorflow/compiler/xla/python:all_files", - "//tensorflow/compiler/xla/service:all_files", - "//tensorflow/compiler/xla/service/cpu:all_files", - "//tensorflow/compiler/xla/service/gpu:all_files", - "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files", - "//tensorflow/compiler/xla/service/interpreter:all_files", - "//tensorflow/compiler/xla/service/llvm_ir:all_files", - "//tensorflow/compiler/xla/tests:all_files", - "//tensorflow/compiler/xla/tools:all_files", - "//tensorflow/compiler/xla/tools/parser:all_files", - "//tensorflow/contrib:all_files", - "//tensorflow/contrib/all_reduce:all_files", - "//tensorflow/contrib/android:all_files", - "//tensorflow/contrib/batching:all_files", - "//tensorflow/contrib/bayesflow:all_files", - "//tensorflow/contrib/boosted_trees:all_files", - "//tensorflow/contrib/boosted_trees/estimator_batch:all_files", - "//tensorflow/contrib/boosted_trees/lib:all_files", - "//tensorflow/contrib/boosted_trees/proto:all_files", - "//tensorflow/contrib/boosted_trees/resources:all_files", - "//tensorflow/contrib/cloud:all_files", - "//tensorflow/contrib/cloud/kernels:all_files", - "//tensorflow/contrib/cluster_resolver:all_files", - "//tensorflow/contrib/coder:all_files", - "//tensorflow/contrib/compiler:all_files", - "//tensorflow/contrib/copy_graph:all_files", - "//tensorflow/contrib/crf:all_files", - "//tensorflow/contrib/cudnn_rnn:all_files", - "//tensorflow/contrib/data:all_files", - "//tensorflow/contrib/data/kernels:all_files", - "//tensorflow/contrib/data/python/kernel_tests:all_files", - "//tensorflow/contrib/data/python/ops:all_files", - "//tensorflow/contrib/decision_trees/proto:all_files", - "//tensorflow/contrib/deprecated:all_files", - "//tensorflow/contrib/distributions:all_files", - "//tensorflow/contrib/eager/proto:all_files", - "//tensorflow/contrib/eager/python:all_files", - "//tensorflow/contrib/estimator:all_files", - "//tensorflow/contrib/factorization:all_files", - "//tensorflow/contrib/factorization/examples:all_files", - "//tensorflow/contrib/factorization/kernels:all_files", - "//tensorflow/contrib/feature_column:all_files", - "//tensorflow/contrib/ffmpeg:all_files", - "//tensorflow/contrib/ffmpeg/default:all_files", - "//tensorflow/contrib/framework:all_files", - "//tensorflow/contrib/fused_conv:all_files", - "//tensorflow/contrib/gan:all_files", - "//tensorflow/contrib/gdr:all_files", - "//tensorflow/contrib/graph_editor:all_files", - "//tensorflow/contrib/grid_rnn:all_files", - "//tensorflow/contrib/hooks:all_files", - "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files", - "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", - "//tensorflow/contrib/image:all_files", - "//tensorflow/contrib/input_pipeline:all_files", - "//tensorflow/contrib/input_pipeline/kernels:all_files", - "//tensorflow/contrib/integrate:all_files", - "//tensorflow/contrib/keras:all_files", - "//tensorflow/contrib/kernel_methods:all_files", - "//tensorflow/contrib/kfac:all_files", - "//tensorflow/contrib/kfac/examples:all_files", - "//tensorflow/contrib/kfac/examples/tests:all_files", - "//tensorflow/contrib/kfac/python/kernel_tests:all_files", - "//tensorflow/contrib/kfac/python/ops:all_files", - "//tensorflow/contrib/labeled_tensor:all_files", - "//tensorflow/contrib/layers:all_files", - "//tensorflow/contrib/layers/kernels:all_files", - "//tensorflow/contrib/learn:all_files", - "//tensorflow/contrib/learn/python/learn/datasets:all_files", - "//tensorflow/contrib/legacy_seq2seq:all_files", - "//tensorflow/contrib/libsvm:all_files", - "//tensorflow/contrib/linalg:all_files", - "//tensorflow/contrib/linear_optimizer:all_files", - "//tensorflow/contrib/lite:all_files", - "//tensorflow/contrib/lite/java:all_files", - "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", - "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", - "//tensorflow/contrib/lite/java/src/main/native:all_files", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", - "//tensorflow/contrib/lite/kernels:all_files", - "//tensorflow/contrib/lite/kernels/internal:all_files", - "//tensorflow/contrib/lite/models/smartreply:all_files", - "//tensorflow/contrib/lite/nnapi:all_files", - "//tensorflow/contrib/lite/python:all_files", - "//tensorflow/contrib/lite/schema:all_files", - "//tensorflow/contrib/lite/testing:all_files", - "//tensorflow/contrib/lite/toco:all_files", - "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", - "//tensorflow/contrib/lite/toco/python:all_files", - "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", - "//tensorflow/contrib/lite/toco/tflite:all_files", - "//tensorflow/contrib/lite/tools:all_files", - "//tensorflow/contrib/lookup:all_files", - "//tensorflow/contrib/losses:all_files", - "//tensorflow/contrib/makefile:all_files", - "//tensorflow/contrib/memory_stats:all_files", - "//tensorflow/contrib/meta_graph_transform:all_files", - "//tensorflow/contrib/metrics:all_files", - "//tensorflow/contrib/model_pruning:all_files", - "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", - "//tensorflow/contrib/nccl:all_files", - "//tensorflow/contrib/nearest_neighbor:all_files", - "//tensorflow/contrib/nn:all_files", - "//tensorflow/contrib/opt:all_files", - "//tensorflow/contrib/periodic_resample:all_files", - "//tensorflow/contrib/predictor:all_files", - "//tensorflow/contrib/py2tf:all_files", - "//tensorflow/contrib/py2tf/converters:all_files", - "//tensorflow/contrib/py2tf/impl:all_files", - "//tensorflow/contrib/py2tf/pyct:all_files", - "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", - "//tensorflow/contrib/py2tf/utils:all_files", - "//tensorflow/contrib/quantize:all_files", - "//tensorflow/contrib/receptive_field:all_files", - "//tensorflow/contrib/reduce_slice_ops:all_files", - "//tensorflow/contrib/remote_fused_graph/pylib:all_files", - "//tensorflow/contrib/resampler:all_files", - "//tensorflow/contrib/rnn:all_files", - "//tensorflow/contrib/saved_model:all_files", - "//tensorflow/contrib/saved_model/cc/saved_model:all_files", - "//tensorflow/contrib/seq2seq:all_files", - "//tensorflow/contrib/session_bundle:all_files", - "//tensorflow/contrib/session_bundle/example:all_files", - "//tensorflow/contrib/signal:all_files", - "//tensorflow/contrib/slim:all_files", - "//tensorflow/contrib/slim/python/slim/data:all_files", - "//tensorflow/contrib/slim/python/slim/nets:all_files", - "//tensorflow/contrib/solvers:all_files", - "//tensorflow/contrib/sparsemax:all_files", - "//tensorflow/contrib/specs:all_files", - "//tensorflow/contrib/staging:all_files", - "//tensorflow/contrib/stat_summarizer:all_files", - "//tensorflow/contrib/stateless:all_files", - "//tensorflow/contrib/summary:all_files", - "//tensorflow/contrib/tensor_forest:all_files", - "//tensorflow/contrib/tensor_forest/hybrid:all_files", - "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", - "//tensorflow/contrib/tensor_forest/proto:all_files", - "//tensorflow/contrib/tensorboard:all_files", - "//tensorflow/contrib/tensorboard/db:all_files", - "//tensorflow/contrib/tensorrt:all_files", - "//tensorflow/contrib/testing:all_files", - "//tensorflow/contrib/text:all_files", - "//tensorflow/contrib/tfprof:all_files", - "//tensorflow/contrib/timeseries:all_files", - "//tensorflow/contrib/timeseries/examples:all_files", - "//tensorflow/contrib/timeseries/python/timeseries:all_files", - "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", - "//tensorflow/contrib/tpu:all_files", - "//tensorflow/contrib/tpu/profiler:all_files", - "//tensorflow/contrib/tpu/proto:all_files", - "//tensorflow/contrib/training:all_files", - "//tensorflow/contrib/util:all_files", - "//tensorflow/contrib/verbs:all_files", - "//tensorflow/core:all_files", - "//tensorflow/core/api_def:all_files", - "//tensorflow/core/debug:all_files", - "//tensorflow/core/distributed_runtime:all_files", - "//tensorflow/core/distributed_runtime/rpc:all_files", - "//tensorflow/core/grappler:all_files", - "//tensorflow/core/grappler/clusters:all_files", - "//tensorflow/core/grappler/costs:all_files", - "//tensorflow/core/grappler/inputs:all_files", - "//tensorflow/core/grappler/optimizers:all_files", - "//tensorflow/core/grappler/utils:all_files", - "//tensorflow/core/kernels:all_files", - "//tensorflow/core/kernels/batching_util:all_files", - "//tensorflow/core/kernels/data:all_files", - "//tensorflow/core/kernels/data/sql:all_files", - "//tensorflow/core/kernels/fuzzing:all_files", - "//tensorflow/core/kernels/hexagon:all_files", - "//tensorflow/core/kernels/neon:all_files", - "//tensorflow/core/lib/db:all_files", - "//tensorflow/core/ops/compat:all_files", - "//tensorflow/core/platform/cloud:all_files", - "//tensorflow/core/platform/default/build_config:all_files", - "//tensorflow/core/platform/hadoop:all_files", - "//tensorflow/core/platform/s3:all_files", - "//tensorflow/core/profiler:all_files", - "//tensorflow/core/profiler/internal:all_files", - "//tensorflow/core/profiler/internal/advisor:all_files", - "//tensorflow/core/util/ctc:all_files", - "//tensorflow/core/util/tensor_bundle:all_files", - "//tensorflow/examples/adding_an_op:all_files", - "//tensorflow/examples/android:all_files", - "//tensorflow/examples/benchmark:all_files", - "//tensorflow/examples/get_started/regression:all_files", - "//tensorflow/examples/how_tos/reading_data:all_files", - "//tensorflow/examples/image_retraining:all_files", - "//tensorflow/examples/label_image:all_files", - "//tensorflow/examples/learn:all_files", - "//tensorflow/examples/multibox_detector:all_files", - "//tensorflow/examples/saved_model:all_files", - "//tensorflow/examples/speech_commands:all_files", - "//tensorflow/examples/tutorials/estimators:all_files", - "//tensorflow/examples/tutorials/layers:all_files", - "//tensorflow/examples/tutorials/mnist:all_files", - "//tensorflow/examples/tutorials/monitors:all_files", - "//tensorflow/examples/tutorials/word2vec:all_files", - "//tensorflow/examples/wav_to_spectrogram:all_files", - "//tensorflow/go:all_files", - "//tensorflow/java:all_files", - "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", - "//tensorflow/java/src/main/native:all_files", - "//tensorflow/python:all_files", - "//tensorflow/python/data:all_files", - "//tensorflow/python/data/kernel_tests:all_files", - "//tensorflow/python/data/ops:all_files", - "//tensorflow/python/data/util:all_files", - "//tensorflow/python/debug:all_files", - "//tensorflow/python/eager:all_files", - "//tensorflow/python/estimator:all_files", - "//tensorflow/python/feature_column:all_files", - "//tensorflow/python/keras:all_files", - "//tensorflow/python/kernel_tests:all_files", - "//tensorflow/python/kernel_tests/distributions:all_files", - "//tensorflow/python/kernel_tests/linalg:all_files", - "//tensorflow/python/kernel_tests/random:all_files", - "//tensorflow/python/ops/distributions:all_files", - "//tensorflow/python/ops/linalg:all_files", - "//tensorflow/python/ops/losses:all_files", - "//tensorflow/python/profiler:all_files", - "//tensorflow/python/profiler/internal:all_files", - "//tensorflow/python/saved_model:all_files", - "//tensorflow/python/tools:all_files", - "//tensorflow/tools/api/generator:all_files", - "//tensorflow/tools/api/golden:all_files", - "//tensorflow/tools/api/lib:all_files", - "//tensorflow/tools/api/tests:all_files", - "//tensorflow/tools/benchmark:all_files", - "//tensorflow/tools/build_info:all_files", - "//tensorflow/tools/ci_build/gpu_build:all_files", - "//tensorflow/tools/common:all_files", - "//tensorflow/tools/compatibility:all_files", - "//tensorflow/tools/dist_test/server:all_files", - "//tensorflow/tools/docker:all_files", - "//tensorflow/tools/docker/notebooks:all_files", - "//tensorflow/tools/docs:all_files", - "//tensorflow/tools/git:all_files", - "//tensorflow/tools/graph_transforms:all_files", - "//tensorflow/tools/mlpbtxt:all_files", - "//tensorflow/tools/proto_text:all_files", - "//tensorflow/tools/quantization:all_files", - "//tensorflow/tools/test:all_files", - "//tensorflow/user_ops:all_files", - "//third_party/eigen3:all_files", - "//third_party/fft2d:all_files", - "//third_party/flatbuffers:all_files", - "//third_party/hadoop:all_files", - "//third_party/sycl:all_files", - "//third_party/sycl/sycl:all_files", - ], - visibility = ["//visibility:public"], -) - load( "//third_party/mkl:build_defs.bzl", "if_mkl", @@ -746,11 +450,12 @@ tf_cc_shared_object( linkstatic = 1, visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:core_cpu_impl", "//tensorflow/core:framework_internal_impl", + "//tensorflow/core:gpu_runtime_impl", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", - "//tensorflow/core:core_cpu_impl", "//tensorflow/stream_executor:stream_executor_impl", - "//tensorflow/core:gpu_runtime_impl", ] + tf_additional_binary_deps(), ) @@ -773,7 +478,7 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow/c:exported_symbols.lds", + "$(location //tensorflow/c:exported_symbols.lds)", "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], @@ -782,11 +487,12 @@ tf_cc_shared_object( "-z defs", "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow/c:version_script.lds", + "$(location //tensorflow/c:version_script.lds)", ], }), deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", "//tensorflow/c:exported_symbols.lds", "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", @@ -799,7 +505,7 @@ tf_cc_shared_object( linkopts = select({ "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file - "//tensorflow:tf_exported_symbols.lds", + "$(location //tensorflow:tf_exported_symbols.lds)", ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], @@ -807,7 +513,7 @@ tf_cc_shared_object( "-z defs", "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file - "//tensorflow:tf_version_script.lds", + "$(location //tensorflow:tf_version_script.lds)", ], }), deps = [ @@ -829,3 +535,14 @@ exports_files( "tf_exported_symbols.lds", ], ) + +py_library( + name = "tensorflow_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python", + "//tensorflow/tools/api/generator:python_api", + ], +) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 78ad6aec19f3bbbfcb389012ac1577573b3e4901..c8683e3976c90add3f1f54d8e575c798327e9273 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -20,14 +20,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import # pylint: disable=wildcard-import -from tensorflow.python import * # pylint: disable=redefined-builtin +from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin # pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + del absolute_import del division del print_function diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 9060c58c1395f07eff0ccef7bd430b3402f8c826..8a9301d584775cff3ae315e6fd856b00d1734248 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -12,18 +12,15 @@ load( "tf_custom_op_library", ) -# For platform specific build config -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_kernel_tests_linkstatic", -) - # ----------------------------------------------------------------------------- # Public targets filegroup( name = "headers", - srcs = ["c_api.h"], + srcs = [ + "c_api.h", + "c_api_experimental.h", + ], visibility = ["//tensorflow:__subpackages__"], ) @@ -34,7 +31,13 @@ filegroup( "*.cc", "*.h", ], - exclude = ["*test*"], + exclude = [ + "c_api_experimental.cc", + "c_api_experimental.h", + "python_api.cc", + "python_api.h", + "*test*", + ], ), visibility = ["//visibility:public"], ) @@ -101,6 +104,29 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "c_api_experimental", + srcs = [ + "c_api_experimental.cc", + ], + hdrs = [ + "c_api_experimental.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api", + ":c_api_internal", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/contrib/tpu:all_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_platform", + "//tensorflow/core:protos_all_cc", + ], +) + exports_files( [ "version_script.lds", @@ -148,7 +174,7 @@ tf_cuda_library( ], deps = [ ":c_api", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + ":c_api_experimental", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", @@ -193,6 +219,27 @@ tf_cuda_cc_test( ], ) +tf_cc_test( + name = "c_api_experimental_test", + size = "small", + srcs = ["c_api_experimental_test.cc"], + data = ["testdata/tf_record"], + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + # We must ensure that the dependencies can be dynamically linked since + # the shared library must be able to use core:framework. + # linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api_experimental", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "c_api_function_test", size = "small", @@ -237,20 +284,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + # TODO(b/74620627): remove when _USE_C_SHAPES is removed + "//tensorflow/python:cpp_shape_inference_proto_cc", ], ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 85f1d1639b4d09f2de77d326481a86ec246270d0..18eeb2816807ec9986999cfc2c9a4c0f032683c0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -30,6 +30,7 @@ limitations under the License. #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eval_const_tensor.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/log_memory.h" @@ -62,6 +63,7 @@ limitations under the License. // brain namespace because we are defining 'extern "C"' functions. using tensorflow::AllocationDescription; using tensorflow::DataType; +using tensorflow::ExtendSessionGraphHelper; using tensorflow::Graph; using tensorflow::GraphDef; using tensorflow::mutex_lock; @@ -73,6 +75,7 @@ using tensorflow::NodeBuilder; using tensorflow::NodeDef; using tensorflow::OpDef; using tensorflow::OpRegistry; +using tensorflow::OutputTensor; using tensorflow::PartialTensorShape; using tensorflow::RunMetadata; using tensorflow::RunOptions; @@ -638,17 +641,17 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, } void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type) - EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + const char* mutation_type) { // If any session has already run this node_id, mark this session as // unrunnable. for (auto it : graph->sessions) { + mutex_lock session_lock(it.first->mu); if (it.first->last_num_graph_nodes > op.node.id()) { - it.second = FailedPrecondition( + it.second = strings::StrCat( "Operation '", op.node.DebugString(), "' was changed by ", mutation_type, - " after it was run by a session. Nodes can be mutated " - "only before they are executed by a session. Either don't modify " + " after it was run by a session. This mutation will have no effect, " + "and will trigger an error in the future. Either don't modify " "nodes after running them or create a new session."); } } @@ -708,6 +711,61 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); +// TODO(josh11b,mrry): Change Session to be able to use a Graph* +// directly, instead of requiring us to serialize to a GraphDef and +// call Session::Extend(). +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { + if (session->graph != nullptr) { + // Take the graph lock before the session lock to avoid deadlock. This is + // safe since session->graph does not change. + session->graph->mu.lock(); + mutex_lock session_lock(session->mu); + const Graph& graph = session->graph->graph; + + const string& mutation_warning = session->graph->sessions[session]; + if (!mutation_warning.empty()) { + // TODO(b/74949947): turn this back into an error status + LOG(WARNING) << mutation_warning; + session->graph->sessions[session].clear(); + } + + const auto num_nodes = graph.num_node_ids(); + if (session->last_num_graph_nodes < num_nodes) { + status->status = tensorflow::ValidateNoCycles(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + + GraphDef graph_def; + *graph_def.mutable_versions() = graph.versions(); + // Fill graph_def with nodes with ids in the range + // [session->last_num_graph_nodes, num_nodes), that is the nodes + // added since the last TF_SessionRun() call. + for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { + Node* const node = graph.FindNodeId(id); + if (node != nullptr && node->IsOp()) { + NodeDef* const node_def = graph_def.add_node(); + *node_def = node->def(); + } + } + *graph_def.mutable_library() = graph.flib_def().ToProto(); + session->graph->mu.unlock(); + status->status = session->session->Extend(graph_def); + if (!status->status.ok()) { + // Contract is we always delete input_values[i]. + return false; + } + // Note: session->session is not modified if Extend() fails, so + // we only set last_num_graph_nodes if it succeeds. + session->last_num_graph_nodes = num_nodes; + } else { + session->graph->mu.unlock(); + } + } + return true; +} + } // namespace tensorflow static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, @@ -2408,11 +2466,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, // TF_Session functions ---------------------------------------------- TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) { - if (s->LocalDeviceManager(&device_mgr).ok()) { - devices = device_mgr->ListDevices(); - } -} + : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {} TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { @@ -2422,7 +2476,7 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->sessions[new_session] = Status::OK(); + graph->sessions[new_session] = ""; } return new_session; } else { @@ -2488,7 +2542,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->sessions[session] = Status::OK(); + graph->sessions[session] = ""; session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2512,58 +2566,6 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { delete s; } -// TODO(josh11b,mrry): Change Session to be able to use a Graph* -// directly, instead of requiring us to serialize to a GraphDef and -// call Session::Extend(). -static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { - if (session->graph != nullptr) { - mutex_lock session_lock(session->mu); - session->graph->mu.lock(); - const Graph& graph = session->graph->graph; - - status->status = session->graph->sessions[session]; - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - const auto num_nodes = graph.num_node_ids(); - if (session->last_num_graph_nodes < num_nodes) { - status->status = tensorflow::ValidateNoCycles(session->graph->graph); - if (!status->status.ok()) { - session->graph->mu.unlock(); - return false; - } - - GraphDef graph_def; - *graph_def.mutable_versions() = graph.versions(); - // Fill graph_def with nodes with ids in the range - // [session->last_num_graph_nodes, num_nodes), that is the nodes - // added since the last TF_SessionRun() call. - for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) { - Node* const node = graph.FindNodeId(id); - if (node != nullptr && node->IsOp()) { - NodeDef* const node_def = graph_def.add_node(); - *node_def = node->def(); - } - } - *graph_def.mutable_library() = graph.flib_def().ToProto(); - session->graph->mu.unlock(); - status->status = session->session->Extend(graph_def); - if (!status->status.ok()) { - // Contract is we always delete input_values[i]. - return false; - } - // Note: session->session is not modified if Extend() fails, so - // we only set last_num_graph_nodes if it succeeds. - session->last_num_graph_nodes = num_nodes; - } else { - session->graph->mu.unlock(); - } - } - return true; -} - void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, const TF_Output* outputs, @@ -2573,7 +2575,8 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2610,7 +2613,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, const char** handle, TF_Status* status) { *handle = nullptr; - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2653,7 +2657,8 @@ void TF_SessionPRun(TF_Session* session, const char* handle, // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and // call Session::Extend(). - if (!ExtendSessionGraphHelper(session, status)) { + if (session->extend_before_run && + !ExtendSessionGraphHelper(session, status)) { return; } @@ -2682,6 +2687,24 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output, + TF_Tensor** result, TF_Status* status) { + *result = nullptr; + mutex_lock l(graph->mu); + OutputTensor tensor(&output.oper->node, output.index); + bool evaluated; + Tensor result_tensor; + status->status = EvaluateConstantTensor( + tensor, graph->refiner, *graph->graph.op_registry(), + graph->graph.versions().producer(), &evaluated, &result_tensor); + if (evaluated) { + DCHECK(status->status.ok()); + *result = TF_TensorFromTensor(result_tensor, status); + if (!status->status.ok()) evaluated = false; + } + return evaluated; +} + TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { tensorflow::OpList op_list; if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index ad592ef70961ef427bfe9fd322a82bd64df7f9f1..c8594347451dffd465d7fa926cc53818dc9e38d4 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -72,7 +72,7 @@ limitations under the License. #ifdef SWIG #define TF_CAPI_EXPORT #else -#if defined(COMPILER_MSVC) +#if defined(_WIN32) #ifdef TF_COMPILE_LIBRARY #define TF_CAPI_EXPORT __declspec(dllexport) #else @@ -80,7 +80,7 @@ limitations under the License. #endif // TF_COMPILE_LIBRARY #else #define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // COMPILER_MSVC +#endif // _WIN32 #endif // SWIG #ifdef __cplusplus @@ -1275,13 +1275,22 @@ TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( // Deleting a function does not remove it from any graphs it was copied to. TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); +// Attempts to evaluate `output`. This will only be possible if `output` doesn't +// depend on any graph inputs (this function is safe to call if this isn't the +// case though). +// +// If the evaluation is successful, this function returns true and `output`s +// value is returned in `result`. Otherwise returns false. An error status is +// returned if something is wrong with the graph or input. Note that this may +// return false even if no error status is set. +TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, + TF_Output output, + TF_Tensor** result, + TF_Status* status); + // TODO(josh11b): Register OpDef, available to all operations added // to this graph. -// The following two may both benefit from a subgraph-definition API -// that re-uses most of the graph-definition API. -// TODO(andydavis): Add functions to a graph. - // -------------------------------------------------------------------------- // API for driving Graph execution. @@ -1487,7 +1496,8 @@ TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieves the type of the device at the given index. // @@ -1497,14 +1507,15 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, // If index is out of bounds, an error code will be set in the status object, // and a null pointer will be returned. TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, - int index, TF_Status*); + int index, + TF_Status* status); // Retrieve the amount of memory associated with a given device. // // If index is out of bounds, an error code will be set in the status object, // and -1 will be returned. TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( - const TF_DeviceList* list, int index, TF_Status*); + const TF_DeviceList* list, int index, TF_Status* status); // -------------------------------------------------------------------------- // Load plugins containing custom ops and kernels diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..d3916bc16778a942b7eab4df93bbc19955b19e31 --- /dev/null +++ b/tensorflow/c/c_api_experimental.cc @@ -0,0 +1,8370 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/protobuf/config.pb.h" + +using tensorflow::FunctionDef; +using tensorflow::Node; +using tensorflow::NodeBuilder; +using tensorflow::Status; + +namespace { +typedef std::unique_ptr + UniqueFuncPtr; +} + +// struct TF_Operation { tensorflow::Node node; }; +static TF_Operation* ToTF_Operation(Node* node) { + return static_cast(static_cast(node)); +} + +void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { + tensorflow::ConfigProto& config = options->options.config; + auto* optimizer_options = + config.mutable_graph_options()->mutable_optimizer_options(); + if (enable) { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + + // These XLA flags are needed to trigger XLA properly from C (more generally + // non-Python) clients. If this API is called again with `enable` set to + // false, it is safe to keep these flag values as is. + tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = true; + flags->tf_xla_min_cluster_size = 1; + } else { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); + } +} + +const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) { + tensorflow::mutex_lock c(graph->mu); + const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString(); + *len = debug_str.size(); + char* ret = static_cast(malloc(*len + 1)); + memcpy(ret, debug_str.c_str(), *len + 1); + return ret; +} + +// On success, returns a set of TF_Function instances from `text_proto` of +// GraphDef type. These functions must be deleted by calling TF_DeleteFunction. +// +// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto, +// before creating a TF_Function out of the possibly mutated proto. +static std::vector CreateFunctionsFromTextProto( + const char* text_proto, + std::function* mutate_proto_func, TF_Status* status) { + tensorflow::GraphDef gdef; + if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) { + status->status = tensorflow::errors::Internal( + "Invalid text proto for GraphDef: ", text_proto); + return {}; + } + const auto& fdef_lib = gdef.library(); + if (fdef_lib.gradient_size() > 0) { + status->status = tensorflow::errors::Internal( + "GradientDef is not supported in reading Dataset related functions: ", + text_proto); + return {}; + } + std::vector ret; + for (const FunctionDef& fdef : fdef_lib.function()) { + // Make a copy so that we can mutate it. + FunctionDef fdef_to_load = fdef; + if (mutate_proto_func) { + (*mutate_proto_func)(&fdef_to_load); + } + VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString(); + std::vector binary_proto_buf(fdef_to_load.ByteSizeLong()); + fdef_to_load.SerializeToArray(binary_proto_buf.data(), + binary_proto_buf.size()); + TF_Function* func = TF_FunctionImportFunctionDef( + binary_proto_buf.data(), binary_proto_buf.size(), status); + if (!status->status.ok()) return {}; + ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction)); + } + return ret; +} + +// On success, returns a newly created TF_Function instance encoding a dataset +// node stack that returns a sequence of 3 floats, and sets `dataset_name` to +// the created dataset name. The returned function must be deleted by calling +// TF_DeleteFunction. +static UniqueFuncPtr CreateFakeDatasetFunction(std::string* dataset_name, + TF_Status* status) { + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "_make_dataset_d8de2712" + output_arg { + name: "TensorSliceDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/tensors/component_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\000\000(B\000\000,B\000\0000B" + } + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/tensors/component_0:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + ret { + key: "TensorSliceDataset" + value: "TensorSliceDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_d8de2712"; + auto functions = CreateFunctionsFromTextProto( + func_def, /*mutate_proto_func*/ nullptr, status); + DCHECK_EQ(functions.size(), 1); + return std::move(functions[0]); +} + +#if not defined(PLATFORM_WINDOWS) +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateImagenetDatasetFunctions( + const char* file_path, std::string* dataset_name, TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return std::vector(); +#else + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_91295dea" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "FlatMapDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "flat_filenames/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: -1 + } + } + } + } + node_def { + name: "flat_filenames" + op: "Reshape" + input: "arg0" + input: "flat_filenames/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "flat_filenames:output:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "FlatMapDataset" + op: "FlatMapDataset" + input: "TensorSliceDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_0cc8c35b" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + ret { + key: "FlatMapDataset" + value: "FlatMapDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_0cc8c35b" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "TFRecordDataset" + type: DT_VARIANT + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "compression_type" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8388608 + } + } + } + } + node_def { + name: "TFRecordDataset" + op: "TFRecordDataset" + input: "arg0" + input: "compression_type:output:0" + input: "buffer_size:output:0" + } + ret { + key: "TFRecordDataset" + value: "TFRecordDataset:handle:0" + } + } + function { + signature { + name: "tf_map_func_74b6b15c" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "Reshape_1" + type: DT_FLOAT + } + output_arg { + name: "sub_1" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + is_stateful: true + } + node_def { + name: "ParseSingleExample/key_image/class/label" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape" + op: "Reshape" + input: "ParseSingleExample/key_image/class/label:output:0" + input: "ParseSingleExample/Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/class/text" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_1" + op: "Reshape" + input: "ParseSingleExample/key_image/class/text:output:0" + input: "ParseSingleExample/Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/encoded" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_2" + op: "Reshape" + input: "ParseSingleExample/key_image/encoded:output:0" + input: "ParseSingleExample/Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/key_image/format" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "jpeg" + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "ParseSingleExample/Reshape_3" + op: "Reshape" + input: "ParseSingleExample/key_image/format:output:0" + input: "ParseSingleExample/Reshape_3/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ParseSingleExample/ParseSingleExample" + op: "ParseSingleExample" + input: "arg0" + input: "ParseSingleExample/Reshape:output:0" + input: "ParseSingleExample/Reshape_1:output:0" + input: "ParseSingleExample/Reshape_2:output:0" + input: "ParseSingleExample/Reshape_3:output:0" + attr { + key: "Tdense" + value { + list { + type: DT_INT64 + type: DT_STRING + type: DT_STRING + type: DT_STRING + } + } + } + attr { + key: "dense_keys" + value { + list { + s: "image/class/label" + s: "image/class/text" + s: "image/encoded" + s: "image/format" + } + } + } + attr { + key: "dense_shapes" + value { + list { + shape { + } + shape { + } + shape { + } + shape { + } + } + } + } + attr { + key: "num_sparse" + value { + i: 5 + } + } + attr { + key: "sparse_keys" + value { + list { + s: "image/object/bbox/xmax" + s: "image/object/bbox/xmin" + s: "image/object/bbox/ymax" + s: "image/object/bbox/ymin" + s: "image/object/class/label" + } + } + } + attr { + key: "sparse_types" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_FLOAT + type: DT_INT64 + } + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:2" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/Substr/pos:output:0" + input: "decode_image/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/pos" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr/len" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Substr" + op: "Substr" + input: "Reshape:output:0" + input: "decode_image/is_jpeg/Substr/pos:output:0" + input: "decode_image/is_jpeg/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\377\330\377" + } + } + } + } + node_def { + name: "decode_image/is_jpeg/Equal" + op: "Equal" + input: "decode_image/is_jpeg/Substr:output:0" + input: "decode_image/is_jpeg/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/Switch" + op: "Switch" + input: "decode_image/is_jpeg/Equal:z:0" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/pred_id" + op: "Identity" + input: "decode_image/is_jpeg/Equal:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/check_jpeg_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/check_jpeg_channels/x:output:0" + input: "decode_image/cond_jpeg/check_jpeg_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 1, 3) when decoding JPEG images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/check_jpeg_channels:z:0" + input: "decode_image/cond_jpeg/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg" + op: "DecodeJpeg" + input: "decode_image/cond_jpeg/DecodeJpeg/Switch:output_true:0" + input: "^decode_image/cond_jpeg/Assert/Assert" + attr { + key: "acceptable_fraction" + value { + f: 1.0 + } + } + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dct_method" + value { + s: "" + } + } + attr { + key: "fancy_upscaling" + value { + b: true + } + } + attr { + key: "ratio" + value { + i: 1 + } + } + attr { + key: "try_recover_truncated" + value { + b: false + } + } + } + node_def { + name: "decode_image/cond_jpeg/DecodeJpeg/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/y" + op: "Const" + input: "^decode_image/cond_jpeg/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\211PN" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png" + op: "Equal" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/is_png/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/is_png/Switch" + op: "Switch" + input: "decode_image/Substr:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png:z:0" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/is_png:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng" + op: "DecodePng" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1:output_true:0" + attr { + key: "channels" + value { + i: 3 + } + } + attr { + key: "dtype" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch" + op: "Switch" + input: "Reshape:output:0" + input: "decode_image/cond_jpeg/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/DecodePng/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "GIF" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/is_gif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/is_gif/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/is_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/is_png/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@decode_image/Substr" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id" + op: "Identity" + input: "decode_image/cond_jpeg/cond_png/is_gif:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 4 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd" + op: "LogicalAnd" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_gif_channels_1:z:0" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding GIF images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/LogicalAnd:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif" + op: "DecodeGif" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1:output_true:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert/Assert" + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/DecodePng/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch_1" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr" + op: "Substr" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/pos:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/len:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch" + op: "Switch" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif/Switch:output_false:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/pred_id:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Reshape" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "BM" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp" + op: "Equal" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp/y:output:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Unable to decode bytes as JPEG, PNG, GIF, or BMP" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/is_bmp:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels" + op: "NotEqual" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/x:output:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Const" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0" + op: "Const" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/switch_f" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Channels must be in (None, 0, 3) when decoding BMP images" + } + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + op: "Assert" + input: "decode_image/cond_jpeg/cond_png/cond_gif/check_channels:z:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp" + op: "DecodeBmp" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Substr/Switch:output_false:0" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_1/Assert" + input: "^decode_image/cond_jpeg/cond_png/cond_gif/Assert_2/Assert" + attr { + key: "channels" + value { + i: 0 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/cond_gif/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeBmp:image:0" + input: "decode_image/cond_jpeg/cond_png/cond_gif/DecodeGif:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/cond_png/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/cond_gif/Merge:output:0" + input: "decode_image/cond_jpeg/cond_png/DecodePng:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "decode_image/cond_jpeg/Merge" + op: "Merge" + input: "decode_image/cond_jpeg/cond_png/Merge:output:0" + input: "decode_image/cond_jpeg/DecodeJpeg:image:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/Cast" + op: "Cast" + input: "decode_image/cond_jpeg/Merge:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "convert_image/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.00392156885937 + } + } + } + } + node_def { + name: "convert_image" + op: "Mul" + input: "convert_image/Cast:y:0" + input: "convert_image/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 4 + } + } + tensor_content: "\000\000\000\000\000\000\000\000\000\000\200?\000\000\200?" + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.10000000149 + } + } + } + } + node_def { + name: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2" + op: "SampleDistortedBoundingBoxV2" + input: "distorted_bounding_box_crop/Shape:output:0" + input: "Const:output:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2/min_object_covered:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "area_range" + value { + list { + f: 0.0799999982119 + f: 1.0 + } + } + } + attr { + key: "aspect_ratio_range" + value { + list { + f: 0.75 + f: 1.33333337307 + } + } + } + attr { + key: "max_attempts" + value { + i: 1 + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + attr { + key: "use_image_if_no_bounding_boxes" + value { + b: true + } + } + } + node_def { + name: "distorted_bounding_box_crop/Slice" + op: "Slice" + input: "convert_image:z:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:begin:0" + input: "distorted_bounding_box_crop/sample_distorted_bounding_box/SampleDistortedBoundingBoxV2:size:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Shape" + op: "Shape" + input: "convert_image:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Shape_1" + op: "Shape" + input: "distorted_bounding_box_crop/Slice:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "Shape:output:0" + input: "Shape_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "Equal:z:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_BOOL + } + } + } + node_def { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "Sum" + op: "Sum" + input: "Cast:y:0" + input: "Const_1:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node_def { + name: "GreaterEqual/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "GreaterEqual" + op: "GreaterEqual" + input: "Sum:output:0" + input: "GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Switch" + op: "Switch" + input: "GreaterEqual:z:0" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_t" + op: "Identity" + input: "cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/switch_f" + op: "Identity" + input: "cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/pred_id" + op: "Identity" + input: "GreaterEqual:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/Shape" + op: "Shape" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Shape/Switch" + op: "Switch" + input: "convert_image:z:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@convert_image" + } + } + } + } + node_def { + name: "cond/Cast" + op: "Cast" + input: "cond/Shape:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice/stack:output:0" + input: "cond/strided_slice/stack_1:output:0" + input: "cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/strided_slice_1/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_1" + op: "StridedSlice" + input: "cond/Cast:y:0" + input: "cond/strided_slice_1/stack:output:0" + input: "cond/strided_slice_1/stack_1:output:0" + input: "cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Greater" + op: "Greater" + input: "cond/strided_slice:output:0" + input: "cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Switch" + op: "Switch" + input: "cond/Greater:z:0" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_t" + op: "Identity" + input: "cond/cond/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/switch_f" + op: "Identity" + input: "cond/cond/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/pred_id" + op: "Identity" + input: "cond/Greater:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "cond/cond/strided_slice/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice/stack:output:0" + input: "cond/cond/strided_slice/stack_1:output:0" + input: "cond/cond/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1/stack_2" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_1" + op: "StridedSlice" + input: "cond/cond/strided_slice/Switch:output_true:0" + input: "cond/cond/strided_slice_1/stack:output:0" + input: "cond/cond/strided_slice_1/stack_1:output:0" + input: "cond/cond/strided_slice_1/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv" + op: "RealDiv" + input: "cond/cond/strided_slice:output:0" + input: "cond/cond/strided_slice_1:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul/y" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul" + op: "Mul" + input: "cond/cond/truediv:z:0" + input: "cond/cond/mul/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast/x/1" + op: "Const" + input: "^cond/cond/switch_t" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast/x" + op: "Pack" + input: "cond/cond/mul:z:0" + input: "cond/cond/Cast/x/1:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast" + op: "Cast" + input: "cond/cond/Cast/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_2" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_2/stack:output:0" + input: "cond/cond/strided_slice_2/stack_1:output:0" + input: "cond/cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/strided_slice_2/Switch" + op: "Switch" + input: "cond/Cast:y:0" + input: "cond/cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@cond/Cast" + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/cond/strided_slice_3" + op: "StridedSlice" + input: "cond/cond/strided_slice_2/Switch:output_false:0" + input: "cond/cond/strided_slice_3/stack:output:0" + input: "cond/cond/strided_slice_3/stack_1:output:0" + input: "cond/cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/cond/truediv_1" + op: "RealDiv" + input: "cond/cond/strided_slice_2:output:0" + input: "cond/cond/strided_slice_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/mul_1/y" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/mul_1" + op: "Mul" + input: "cond/cond/truediv_1:z:0" + input: "cond/cond/mul_1/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Cast_1/x/0" + op: "Const" + input: "^cond/cond/switch_f" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 224.0 + } + } + } + } + node_def { + name: "cond/cond/Cast_1/x" + op: "Pack" + input: "cond/cond/Cast_1/x/0:output:0" + input: "cond/cond/mul_1:z:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/cond/Cast_1" + op: "Cast" + input: "cond/cond/Cast_1/x:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/cond/Merge" + op: "Merge" + input: "cond/cond/Cast_1:y:0" + input: "cond/cond/Cast:y:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic/images" + op: "Pack" + input: "cond/Shape/Switch:output_true:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic" + op: "ResizeBicubic" + input: "cond/ResizeBicubic/images:output:0" + input: "cond/cond/Merge:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_2/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_2" + op: "StridedSlice" + input: "cond/ResizeBicubic:resized_images:0" + input: "cond/strided_slice_2/stack:output:0" + input: "cond/strided_slice_2/stack_1:output:0" + input: "cond/strided_slice_2/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_1" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_3/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_3" + op: "StridedSlice" + input: "cond/Shape_1:output:0" + input: "cond/strided_slice_3/stack:output:0" + input: "cond/strided_slice_3/stack_1:output:0" + input: "cond/strided_slice_3/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Shape_2" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_4/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_4/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_4" + op: "StridedSlice" + input: "cond/Shape_2:output:0" + input: "cond/strided_slice_4/stack:output:0" + input: "cond/strided_slice_4/stack_1:output:0" + input: "cond/strided_slice_4/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/sub/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub" + op: "Sub" + input: "cond/strided_slice_3:output:0" + input: "cond/sub/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add" + op: "Add" + input: "cond/sub:z:0" + input: "cond/add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv/Cast" + op: "Cast" + input: "cond/add:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv/Cast_1" + op: "Cast" + input: "cond/truediv/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv" + op: "RealDiv" + input: "cond/truediv/Cast:y:0" + input: "cond/truediv/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/sub_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/sub_1" + op: "Sub" + input: "cond/strided_slice_4:output:0" + input: "cond/sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/add_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/add_1" + op: "Add" + input: "cond/sub_1:z:0" + input: "cond/add_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/truediv_1/Cast" + op: "Cast" + input: "cond/add_1:z:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1/Cast_1" + op: "Cast" + input: "cond/truediv_1/y:output:0" + attr { + key: "DstT" + value { + type: DT_DOUBLE + } + } + attr { + key: "SrcT" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/truediv_1" + op: "RealDiv" + input: "cond/truediv_1/Cast:y:0" + input: "cond/truediv_1/Cast_1:y:0" + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Shape_3" + op: "Shape" + input: "cond/strided_slice_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Rank" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/Equal" + op: "Equal" + input: "cond/Rank:output:0" + input: "cond/Equal/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/Assert/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Rank of image must be equal to 3." + } + } + } + } + node_def { + name: "cond/Assert/Assert" + op: "Assert" + input: "cond/Equal:z:0" + input: "cond/Assert/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/strided_slice_5/stack" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 3 + } + } + } + } + node_def { + name: "cond/strided_slice_5/stack_2" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_5" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_5/stack:output:0" + input: "cond/strided_slice_5/stack_1:output:0" + input: "cond/strided_slice_5/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/stack/0" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack/1" + op: "Const" + input: "^cond/Assert/Assert" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/stack" + op: "Pack" + input: "cond/stack/0:output:0" + input: "cond/stack/1:output:0" + input: "cond/strided_slice_5:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/strided_slice_6/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_6" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_6/stack:output:0" + input: "cond/strided_slice_6/stack_1:output:0" + input: "cond/strided_slice_6/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual" + op: "GreaterEqual" + input: "cond/strided_slice_6:output:0" + input: "cond/GreaterEqual/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/strided_slice_7/stack" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_1" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 2 + } + } + } + } + node_def { + name: "cond/strided_slice_7/stack_2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_7" + op: "StridedSlice" + input: "cond/Shape_3:output:0" + input: "cond/strided_slice_7/stack:output:0" + input: "cond/strided_slice_7/stack_1:output:0" + input: "cond/strided_slice_7/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/GreaterEqual_1/y" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 224 + } + } + } + } + node_def { + name: "cond/GreaterEqual_1" + op: "GreaterEqual" + input: "cond/strided_slice_7:output:0" + input: "cond/GreaterEqual_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/LogicalAnd" + op: "LogicalAnd" + input: "cond/GreaterEqual:z:0" + input: "cond/GreaterEqual_1:z:0" + } + node_def { + name: "cond/Assert_1/Const" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert/data_0" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "Crop size greater than the image size." + } + } + } + } + node_def { + name: "cond/Assert_1/Assert" + op: "Assert" + input: "cond/LogicalAnd:z:0" + input: "cond/Assert_1/Assert/data_0:output:0" + attr { + key: "T" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "summarize" + value { + i: 3 + } + } + } + node_def { + name: "cond/stack_1/2" + op: "Const" + input: "^cond/switch_t" + attr { + key: "dtype" + value { + type: DT_DOUBLE + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_DOUBLE + tensor_shape { + } + double_val: 0.0 + } + } + } + } + node_def { + name: "cond/stack_1" + op: "Pack" + input: "cond/truediv:z:0" + input: "cond/truediv_1:z:0" + input: "cond/stack_1/2:output:0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_DOUBLE + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ToInt32" + op: "Cast" + input: "cond/stack_1:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_DOUBLE + } + } + } + node_def { + name: "cond/Slice" + op: "Slice" + input: "cond/strided_slice_2:output:0" + input: "cond/ToInt32:y:0" + input: "cond/stack:output:0" + input: "^cond/Assert_1/Assert" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "cond/Reshape" + op: "Reshape" + input: "cond/Slice:output:0" + input: "cond/stack:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images" + op: "Pack" + input: "cond/ResizeBicubic_1/images/Switch:output_false:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "axis" + value { + i: 0 + } + } + } + node_def { + name: "cond/ResizeBicubic_1/images/Switch" + op: "Switch" + input: "distorted_bounding_box_crop/Slice:output:0" + input: "cond/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@distorted_bounding_box_crop/Slice" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1/size" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\340\000\000\000\340\000\000\000" + } + } + } + } + node_def { + name: "cond/ResizeBicubic_1" + op: "ResizeBicubic" + input: "cond/ResizeBicubic_1/images:output:0" + input: "cond/ResizeBicubic_1/size:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } + } + node_def { + name: "cond/strided_slice_8/stack" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_1" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8/stack_2" + op: "Const" + input: "^cond/switch_f" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "cond/strided_slice_8" + op: "StridedSlice" + input: "cond/ResizeBicubic_1:resized_images:0" + input: "cond/strided_slice_8/stack:output:0" + input: "cond/strided_slice_8/stack_1:output:0" + input: "cond/strided_slice_8/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "cond/Merge" + op: "Merge" + input: "cond/strided_slice_8:output:0" + input: "cond/Reshape:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\354Q\370>\325x\351>;\337\317>" + } + } + } + } + node_def { + name: "sub" + op: "Sub" + input: "cond/Merge:output:0" + input: "Const_2:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Const_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 3 + } + } + tensor_content: "\372~j>B`e>fff>" + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "sub:z:0" + input: "Const_3:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/control_dependency" + op: "Identity" + input: "truediv:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/min" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/max" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1.0 + } + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/RandomUniform" + op: "RandomUniform" + input: "random_flip_left_right/random_uniform/shape:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "seed" + value { + i: 0 + } + } + attr { + key: "seed2" + value { + i: 0 + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/sub" + op: "Sub" + input: "random_flip_left_right/random_uniform/max:output:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform/mul" + op: "Mul" + input: "random_flip_left_right/random_uniform/RandomUniform:output:0" + input: "random_flip_left_right/random_uniform/sub:z:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/random_uniform" + op: "Add" + input: "random_flip_left_right/random_uniform/mul:z:0" + input: "random_flip_left_right/random_uniform/min:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Less/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.5 + } + } + } + } + node_def { + name: "random_flip_left_right/Less" + op: "Less" + input: "random_flip_left_right/random_uniform:z:0" + input: "random_flip_left_right/Less/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "random_flip_left_right/Switch" + op: "Switch" + input: "random_flip_left_right/Less:z:0" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_t" + op: "Identity" + input: "random_flip_left_right/Switch:output_true:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/switch_f" + op: "Identity" + input: "random_flip_left_right/Switch:output_false:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/pred_id" + op: "Identity" + input: "random_flip_left_right/Less:z:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/axis" + op: "Const" + input: "^random_flip_left_right/switch_t" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2" + op: "ReverseV2" + input: "random_flip_left_right/ReverseV2/Switch:output_true:0" + input: "random_flip_left_right/ReverseV2/axis:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + node_def { + name: "random_flip_left_right/ReverseV2/Switch" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Switch_1" + op: "Switch" + input: "random_flip_left_right/control_dependency:output:0" + input: "random_flip_left_right/pred_id:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@truediv" + } + } + } + } + node_def { + name: "random_flip_left_right/Merge" + op: "Merge" + input: "random_flip_left_right/Switch_1:output_false:0" + input: "random_flip_left_right/ReverseV2:output:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Reshape_1/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 3 + } + } + tensor_content: "\340\000\000\000\340\000\000\000\003\000\000\000" + } + } + } + } + node_def { + name: "Reshape_1" + op: "Reshape" + input: "random_flip_left_right/Merge:output:0" + input: "Reshape_1/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Reshape_2/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape_2" + op: "Reshape" + input: "ParseSingleExample/ParseSingleExample:dense_values:0" + input: "Reshape_2/shape:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "Cast_1" + op: "Cast" + input: "Reshape_2:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_INT64 + } + } + } + node_def { + name: "sub_1/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "sub_1" + op: "Sub" + input: "Cast_1:y:0" + input: "sub_1/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + ret { + key: "Reshape_1" + value: "Reshape_1:output:0" + } + ret { + key: "sub_1" + value: "sub_1:z:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_5fa5e1f4" + output_arg { + name: "PrefetchDataset_1" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "TensorSliceDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "TensorSliceDataset/MatchingFiles" + op: "MatchingFiles" + input: "TensorSliceDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "TensorSliceDataset" + op: "TensorSliceDataset" + input: "TensorSliceDataset/MatchingFiles:filenames:0" + attr { + key: "Toutput_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles/pattern" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)" + } + } + } + } + node_def { + name: "ShuffleDataset/MatchingFiles" + op: "MatchingFiles" + input: "ShuffleDataset/MatchingFiles/pattern:output:0" + } + node_def { + name: "ShuffleDataset/Shape" + op: "Shape" + input: "ShuffleDataset/MatchingFiles:filenames:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/strided_slice" + op: "StridedSlice" + input: "ShuffleDataset/Shape:output:0" + input: "ShuffleDataset/strided_slice/stack:output:0" + input: "ShuffleDataset/strided_slice/stack_1:output:0" + input: "ShuffleDataset/strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "ShuffleDataset/Maximum/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ShuffleDataset/Maximum" + op: "Maximum" + input: "ShuffleDataset/strided_slice:output:0" + input: "ShuffleDataset/Maximum/y:output:0" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "TensorSliceDataset:handle:0" + input: "ShuffleDataset/Maximum:z:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ShuffleDataset_1/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1/seed2_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_1" + op: "ShuffleDataset" + input: "ShuffleDataset:handle:0" + input: "ShuffleDataset_1/buffer_size:output:0" + input: "ShuffleDataset_1/seed_1:output:0" + input: "ShuffleDataset_1/seed2_1:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "ShuffleDataset_1:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/cycle_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/block_length" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/sloppy" + op: "Const" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: true + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/buffer_output_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset/prefetch_input_elements" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "ParallelInterleaveDataset" + op: "ParallelInterleaveDataset" + input: "RepeatDataset:handle:0" + input: "ParallelInterleaveDataset/cycle_length:output:0" + input: "ParallelInterleaveDataset/block_length:output:0" + input: "ParallelInterleaveDataset/sloppy:output:0" + input: "ParallelInterleaveDataset/buffer_output_elements:output:0" + input: "ParallelInterleaveDataset/prefetch_input_elements:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_91295dea" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + } + node_def { + name: "ShuffleDataset_2/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1024 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2/seed2_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset_2" + op: "ShuffleDataset" + input: "ParallelInterleaveDataset:handle:0" + input: "ShuffleDataset_2/buffer_size_1:output:0" + input: "ShuffleDataset_2/seed_2:output:0" + input: "ShuffleDataset_2/seed2_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "ParallelMapDataset/num_parallel_calls" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 64 + } + } + } + } + node_def { + name: "ParallelMapDataset" + op: "ParallelMapDataset" + input: "ShuffleDataset_2:handle:0" + input: "ParallelMapDataset/num_parallel_calls:output:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_74b6b15c" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "PrefetchDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "PrefetchDataset" + op: "PrefetchDataset" + input: "ParallelMapDataset:handle:0" + input: "PrefetchDataset/buffer_size_2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "PrefetchDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 64 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + node_def { + name: "PrefetchDataset_1/buffer_size_3" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 2 + } + } + } + } + node_def { + name: "PrefetchDataset_1" + op: "PrefetchDataset" + input: "FilterDataset:handle:0" + input: "PrefetchDataset_1/buffer_size_3:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 64 + } + dim { + size: 224 + } + dim { + size: 224 + } + dim { + size: 3 + } + } + shape { + dim { + size: 64 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + ret { + key: "PrefetchDataset_1" + value: "PrefetchDataset_1:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_5fa5e1f4"; + std::function mutate_proto_func = + [dataset_name, file_path](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found = false; + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() != "TensorSliceDataset/MatchingFiles/pattern" && + node_def.name() != "ShuffleDataset/MatchingFiles/pattern") + continue; + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found = true; + DCHECK_EQ(node_def.attr().at("value").tensor().string_val(0), + "$(DATA_DIR)"); + VLOG(1) << "Setting the value of node_def " + "TensorSliceDataset/MatchingFiles/pattern to " + << file_path; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(file_path); + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +#endif +} +#endif + +#if not defined(PLATFORM_WINDOWS) +// On success, returns a set of TF_Function instances encoding a dataset +// node stack that reads an MNIST file dataset from `file_path`, and +// sets `dataset_name` to the created dataset name. The returned functions must +// be deleted by calling TF_DeleteFunction. +static std::vector CreateMNISTDatasetFunctions( + const char* file_path, int batch_size, std::string* dataset_name, + TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return nullptr; +#else + const char* func_def = R"PREFIX( +library { + function { + signature { + name: "tf_map_func_521bfd08" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "truediv" + type: DT_FLOAT + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Cast" + op: "Cast" + input: "DecodeRaw:output:0" + attr { + key: "DstT" + value { + type: DT_FLOAT + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 784 + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "Cast:y:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "truediv/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 255.0 + } + } + } + } + node_def { + name: "truediv" + op: "RealDiv" + input: "Reshape:output:0" + input: "truediv/y:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "truediv" + value: "truediv:z:0" + } + } + function { + signature { + name: "tf_map_func_9a08860d" + input_arg { + name: "arg0" + type: DT_STRING + } + output_arg { + name: "ToInt32" + type: DT_INT32 + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "DecodeRaw" + op: "DecodeRaw" + input: "arg0" + attr { + key: "little_endian" + value { + b: true + } + } + attr { + key: "out_type" + value { + type: DT_UINT8 + } + } + } + node_def { + name: "Reshape/shape" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + } + } + } + } + } + } + node_def { + name: "Reshape" + op: "Reshape" + input: "DecodeRaw:output:0" + input: "Reshape/shape:output:0" + attr { + key: "T" + value { + type: DT_UINT8 + } + } + attr { + key: "Tshape" + value { + type: DT_INT32 + } + } + } + node_def { + name: "ToInt32" + op: "Cast" + input: "Reshape:output:0" + attr { + key: "DstT" + value { + type: DT_INT32 + } + } + attr { + key: "SrcT" + value { + type: DT_UINT8 + } + } + } + ret { + key: "ToInt32" + value: "ToInt32:y:0" + } + } + function { + signature { + name: "tf_predicate_7089b845" + input_arg { + name: "arg0" + type: DT_FLOAT + } + input_arg { + name: "arg1" + type: DT_INT32 + } + input_arg { + name: "Equal/Placeholder" + type: DT_INT64 + } + output_arg { + name: "Equal" + type: DT_BOOL + } + description: "A wrapper for Defun that facilitates shape inference." + } + node_def { + name: "Shape" + op: "Shape" + input: "arg0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "out_type" + value { + type: DT_INT64 + } + } + } + node_def { + name: "strided_slice/stack" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 0 + } + } + } + } + node_def { + name: "strided_slice/stack_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice/stack_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 1 + } + } + int_val: 1 + } + } + } + } + node_def { + name: "strided_slice" + op: "StridedSlice" + input: "Shape:output:0" + input: "strided_slice/stack:output:0" + input: "strided_slice/stack_1:output:0" + input: "strided_slice/stack_2:output:0" + attr { + key: "Index" + value { + type: DT_INT32 + } + } + attr { + key: "T" + value { + type: DT_INT64 + } + } + attr { + key: "begin_mask" + value { + i: 0 + } + } + attr { + key: "ellipsis_mask" + value { + i: 0 + } + } + attr { + key: "end_mask" + value { + i: 0 + } + } + attr { + key: "new_axis_mask" + value { + i: 0 + } + } + attr { + key: "shrink_axis_mask" + value { + i: 1 + } + } + } + node_def { + name: "Equal" + op: "Equal" + input: "strided_slice:output:0" + input: "Equal/Placeholder" + attr { + key: "T" + value { + type: DT_INT64 + } + } + } + ret { + key: "Equal" + value: "Equal:z:0" + } + } + function { + signature { + name: "_make_dataset_2451e43a" + output_arg { + name: "FilterDataset" + type: DT_VARIANT + } + is_stateful: true + } + node_def { + name: "FixedLengthRecordDataset/filenames" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-images-idx3-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/header_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 16 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/record_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 784 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/footer_bytes" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset/buffer_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset/filenames:output:0" + input: "FixedLengthRecordDataset/header_bytes:output:0" + input: "FixedLengthRecordDataset/record_bytes:output:0" + input: "FixedLengthRecordDataset/footer_bytes:output:0" + input: "FixedLengthRecordDataset/buffer_size:output:0" + } + node_def { + name: "MapDataset" + op: "MapDataset" + input: "FixedLengthRecordDataset:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_521bfd08" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/filenames_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "$(DATA_DIR)/train-labels-idx1-ubyte" + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/header_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 8 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/record_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 1 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/footer_bytes_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1/buffer_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 262144 + } + } + } + } + node_def { + name: "FixedLengthRecordDataset_1" + op: "FixedLengthRecordDataset" + input: "FixedLengthRecordDataset_1/filenames_1:output:0" + input: "FixedLengthRecordDataset_1/header_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/record_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/footer_bytes_1:output:0" + input: "FixedLengthRecordDataset_1/buffer_size_1:output:0" + } + node_def { + name: "MapDataset_1" + op: "MapDataset" + input: "FixedLengthRecordDataset_1:handle:0" + attr { + key: "Targuments" + value { + list { + } + } + } + attr { + key: "f" + value { + func { + name: "tf_map_func_9a08860d" + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_INT32 + } + } + } + } + node_def { + name: "ZipDataset" + op: "ZipDataset" + input: "MapDataset:handle:0" + input: "MapDataset_1:handle:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "CacheDataset/filename" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "" + } + } + } + } + node_def { + name: "CacheDataset" + op: "CacheDataset" + input: "ZipDataset:handle:0" + input: "CacheDataset/filename:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "RepeatDataset/count" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -1 + } + } + } + } + node_def { + name: "RepeatDataset" + op: "RepeatDataset" + input: "CacheDataset:handle:0" + input: "RepeatDataset/count:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "ShuffleDataset/buffer_size_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 50000 + } + } + } + } + node_def { + name: "ShuffleDataset/seed" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset/seed2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: 0 + } + } + } + } + node_def { + name: "ShuffleDataset" + op: "ShuffleDataset" + input: "RepeatDataset:handle:0" + input: "ShuffleDataset/buffer_size_2:output:0" + input: "ShuffleDataset/seed:output:0" + input: "ShuffleDataset/seed2:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: 784 + } + } + shape { + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "reshuffle_each_iteration" + value { + b: true + } + } + } + node_def { + name: "BatchDataset/batch_size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "BatchDataset" + op: "BatchDataset" + input: "ShuffleDataset:handle:0" + input: "BatchDataset/batch_size:output:0" + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + } + node_def { + name: "FilterDataset/batch_size_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + } + int64_val: -123 + } + } + } + } + node_def { + name: "FilterDataset" + op: "FilterDataset" + input: "BatchDataset:handle:0" + input: "FilterDataset/batch_size_1:output:0" + attr { + key: "Targuments" + value { + list { + type: DT_INT64 + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 784 + } + } + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + type: DT_INT32 + } + } + } + attr { + key: "predicate" + value { + func { + name: "tf_predicate_7089b845" + } + } + } + } + ret { + key: "FilterDataset" + value: "FilterDataset:handle:0" + } + } +} +)PREFIX"; + + *dataset_name = "_make_dataset_2451e43a"; + std::function mutate_proto_func = + [dataset_name, file_path, batch_size](FunctionDef* fdef) { + VLOG(1) << "Processsing function " << fdef->DebugString(); + if (std::string(fdef->signature().name()) != *dataset_name) return; + // Change the input file pattern to `file_path`. + bool found_file_path = false, found_batch_size = false; + // `node_def` may be mutated. + for (auto& node_def : *fdef->mutable_node_def()) { + if (node_def.name() == "FixedLengthRecordDataset/filenames" || + node_def.name() == "FixedLengthRecordDataset_1/filenames_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_file_path = true; + // Replace $(DATA_DIR)/foo with /foo + // TODO(hongm): Use StringPiece manipulation for better efficiency. + const std::string cur_value = + node_def.attr().at("value").tensor().string_val(0); + const std::string pattern = "$(DATA_DIR)"; + DCHECK_EQ(cur_value.compare(0, pattern.length(), pattern), 0); + const std::string new_value = + file_path + cur_value.substr(pattern.length()); + VLOG(1) << "Setting the value of node_def " << node_def.name() + << " to " << new_value; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_string_val(); + tensor->add_string_val(new_value); + } else if (node_def.name() == "BatchDataset/batch_size" || + node_def.name() == "FilterDataset/batch_size_1") { + DCHECK_EQ(node_def.op(), "Const"); + DCHECK_GT(node_def.attr().count("value"), 0); + found_batch_size = true; + // Replace $(BATCH_SIZE) with `batch_size` + DCHECK_EQ(node_def.attr().at("value").tensor().int64_val(0), -123); + VLOG(1) << "Setting the batch size attr value of node_def " + << node_def.name() << " to " << batch_size; + auto* tensor = (*node_def.mutable_attr())["value"].mutable_tensor(); + tensor->clear_int64_val(); + tensor->add_int64_val(batch_size); + } + } + VLOG(1) << "Rewrote function to " << fdef->DebugString(); + DCHECK(found_file_path); + DCHECK(found_batch_size); + }; + return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); +#endif +} +#endif + +// Adds the input functions to `graph`. On success, returns the created +// IteratorGetNext node. +static TF_Operation* AddDatasetFunctionAndIteratorNodesToGraph( + const std::vector& funcs, const std::string& dataset_name, + const std::vector& output_types, + const std::vector& output_shapes, + TF_Graph* graph, TF_Status* status) { + DCHECK(!dataset_name.empty()); + for (auto& func : funcs) { + TF_GraphCopyFunction(graph, func.get(), /*gradient*/ nullptr, status); + if (!status->status.ok()) { + return nullptr; + } + } + + tensorflow::mutex_lock c(graph->mu); + + tensorflow::NameAttrList func; + func.set_name(dataset_name); + // Run the iterator node on CPU. + Node* oneshot_iterator_node; + tensorflow::Status s = NodeBuilder("OneShotIterator", "OneShotIterator") + .Device("/device:CPU:0") + .Attr("container", "") + .Attr("dataset_factory", func) + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Attr("shared_name", "") + .Finalize(&graph->graph, &oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(oneshot_iterator_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + // Run the iterator node on CPU. + Node* getnext_node; + s = NodeBuilder("IteratorGetNext", "IteratorGetNext") + .Input(oneshot_iterator_node) + .Device("/device:CPU:0") + .Attr("output_types", output_types) + .Attr("output_shapes", output_shapes) + .Finalize(&graph->graph, &getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + // Run shape inference function for each newly added node, so that more + // subsequent nodes can be added to the graph via C API (TF_NewOperation()). + s = graph->refiner.AddNode(getnext_node); + if (!s.ok()) { + status->status = s; + return nullptr; + } + + VLOG(1) << "Output graph: " << graph->graph.ToGraphDefDebug().DebugString(); + return ToTF_Operation(getnext_node); +} + +TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph, + TF_Status* status) { + tensorflow::Status s; + + std::string dataset_name; + UniqueFuncPtr result_func = CreateFakeDatasetFunction(&dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector funcs; + funcs.push_back(std::move(result_func)); + std::vector output_shape_list; + output_shape_list.push_back(tensorflow::TensorShapeProto()); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT}, output_shape_list, graph, + status); + if (!status->status.ok()) { + return nullptr; + } + + return getnext_node; +} + +TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + // TODO(ashankar): get these functions working on Windows. + status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return nullptr; +#else + tensorflow::Status s; + + std::string dataset_name; + const auto& funcs = + is_mnist + ? CreateMNISTDatasetFunctions(file_path, batch_size, &dataset_name, + status) + : CreateImagenetDatasetFunctions(file_path, &dataset_name, status); + if (!status->status.ok()) { + return nullptr; + } + + std::vector output_shape_list; + // batch_size X 224 X 224 X 3 + auto image_shape = tensorflow::TensorShapeProto(); + image_shape.add_dim()->set_size(batch_size); + if (is_mnist) { + image_shape.add_dim()->set_size(784); + } else { + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(224); + image_shape.add_dim()->set_size(3); + } + output_shape_list.push_back(image_shape); + + // batch_size + auto label_shape = tensorflow::TensorShapeProto(); + label_shape.add_dim()->set_size(batch_size); + output_shape_list.push_back(label_shape); + auto* getnext_node = AddDatasetFunctionAndIteratorNodesToGraph( + funcs, dataset_name, {tensorflow::DT_FLOAT, tensorflow::DT_INT32}, + output_shape_list, graph, status); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::mutex_lock c(graph->mu); + VLOG(1) << "The extended graph: " + << graph->graph.ToGraphDefDebug().DebugString(); + + return getnext_node; +#endif +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..88cb173cd25f4219e32392f6722a6ea7d358a553 --- /dev/null +++ b/tensorflow/c/c_api_experimental.h @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_C_API_EXPERIMENTAL_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlow. +// +// The API here is subject to changes in the future. +// -------------------------------------------------------------------------- + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes.$a +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +// When `enable` is true, set +// tensorflow.ConfigProto.OptimizerOptions.global_jit_level to ON_1, and also +// set XLA flag values to prepare for XLA compilation. Otherwise set +// global_jit_level to OFF. +// +// This API is syntax sugar over TF_SetConfig(), and is used by clients that +// cannot read/write the tensorflow.ConfigProto proto. +TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, + unsigned char enable); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Creates a stack of data set + iterator nodes, currently hard-coded to return +// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success, +// returns the IteratorGetNext node, which caller can run or feed into an node. +// +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets( + TF_Graph* graph, TF_Status* status); + +// Similar to the above API, except that the returned iterator reads the +// file based dataset from `file_path`. +// If `is_mnist` is 0, the dataset corresponds to ImageNet. +// The iterators outputs 2 tensors: +// - A float tensor of shape `batch_size` X 784 when `is_mnist` is non-zero, or +// `batch_size` X 224 X 224 X 3 otherwise. +// - An int32 tensor of shape `batch_size` +// TODO(hongm): Extend the API to allow customization of the nodes created. +TF_CAPI_EXPORT extern TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( + TF_Graph* graph, const char* file_path, int batch_size, + unsigned char is_mnist, TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30fcfd401d9d634962d64aaa3bf348de91f2ecae --- /dev/null +++ b/tensorflow/c/c_api_experimental_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +void TestFakeIteratorStack() { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + const float base_value = 42.0; + for (int i = 0; i < 3; ++i) { + csession.SetOutputs({get_next}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(out)); + ASSERT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(float), TF_TensorByteSize(out)); + float* output_contents = static_cast(TF_TensorData(out)); + ASSERT_EQ(base_value + i, *output_contents); + } + + // This should error out since we've exhausted the iterator. + csession.Run(s); + ASSERT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s)) << TF_Message(s); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); } + +TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + const string file_path = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record"); + VLOG(1) << "data file path is " << file_path; + const int batch_size = 64; + TF_Operation* get_next = TF_MakeFileBasedIteratorGetNextWithDatasets( + graph, file_path.c_str(), batch_size, /*is_mnist*/ false, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + // The two output tensors should look like: + // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32) + // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32) + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Running iter " << i; + csession.SetOutputs({{get_next, 0}, {get_next, 1}}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + { + TF_Tensor* image = csession.output_tensor(0); + ASSERT_TRUE(image != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(image)); + // Confirm shape is 224 X 224 X 3 + ASSERT_EQ(4, TF_NumDims(image)); + ASSERT_EQ(batch_size, TF_Dim(image, 0)); + ASSERT_EQ(224, TF_Dim(image, 1)); + ASSERT_EQ(224, TF_Dim(image, 2)); + ASSERT_EQ(3, TF_Dim(image, 3)); + ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3, + TF_TensorByteSize(image)); + } + + { + TF_Tensor* label = csession.output_tensor(1); + ASSERT_TRUE(label != nullptr); + ASSERT_EQ(TF_INT32, TF_TensorType(label)); + ASSERT_EQ(1, TF_NumDims(label)); + ASSERT_EQ(batch_size, TF_Dim(label, 0)); + ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label)); + } + } + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 7ca50119eafe299b307f06c555aec1388e7e82e2..610274696f5940c063e68f2310cfd9cc1e0bd964 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 91667056e0eeb224b4b8a034766f11a123cd1a03..95652a11378d6276b5ba6540a07baa15aa77cc1c 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -84,19 +84,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // The keys of this map are all the active sessions using this graph. - // Each value is the current "runnability" status of the corresponding - // session. Under normal conditions all statuses are Status::OK(), but - // if some operation is mutated after it was run by a session (this - // is detected in RecordMutation function), that session is no longer - // safe to run. Its status will contain the error that will be returned - // to the user, should she try running this session. + // The keys of this map are all the active sessions using this graph. Each + // value records whether the graph has been mutated since the corresponding + // session has been run (this is detected in RecordMutation function). If the + // string is empty, no mutation has occurred. Otherwise the string is a + // description of the mutation suitable for returning to the user. // // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. // TF_Graph may only / must be deleted when // sessions.size() == 0 && delete_requested == true - tensorflow::gtl::FlatMap sessions + // + // TODO(b/74949947): mutations currently trigger a warning instead of a bad + // status, this should be reverted when possible. + tensorflow::gtl::FlatMap sessions GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph @@ -124,15 +125,16 @@ struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; - TF_Graph* graph; + TF_Graph* const graph; - tensorflow::mutex mu; + tensorflow::mutex mu ACQUIRED_AFTER(TF_Graph::mu); int last_num_graph_nodes; - // NOTE(ashankar): Experimental fields to help keep the - // buffers of a TF_Tensor pinned in device memory. - const tensorflow::DeviceMgr* device_mgr; // Owned by session. - std::vector devices; // Owned by device_mgr. + // If true, TF_SessionRun and similar methods will call + // ExtendSessionGraphHelper before running the graph (this is the default + // public behavior). Can be set to false if the caller needs to call + // ExtendSessionGraphHelper manually. + std::atomic extend_before_run; }; struct TF_ImportGraphDefOptions { @@ -210,7 +212,11 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, TF_Status* status); void RecordMutation(TF_Graph* graph, const TF_Operation& op, - const char* mutation_type); + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu); + +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) + LOCKS_EXCLUDED(session->graph->mu, session->mu); } // end namespace tensorflow diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 028f146be31790b211e546978302e81afe26b231..ca80db23ed3ccbbdc49c61db6cd03ff735470512 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -53,7 +53,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index a55af46ae2baef1cd4f55f478ec234551f370503..f3b28c1708129d39e451d927a89c0d10e2193b63 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -34,6 +34,10 @@ static void DoubleDeallocator(void* data, size_t, void* arg) { delete[] static_cast(data); } +static void FloatDeallocator(void* data, size_t, void* arg) { + delete[] static_cast(data); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -78,21 +82,34 @@ TF_Tensor* DoubleTensor(double v) { &DoubleDeallocator, nullptr); } +TF_Tensor* FloatTensor(float v) { + const int num_bytes = sizeof(float); + float* values = new float[1]; + values[0] = v; + return TF_NewTensor(TF_FLOAT, nullptr, 0, values, num_bytes, + &FloatDeallocator, nullptr); +} + // All the *Helper methods are used as a workaround for the restrictions that // one cannot call ASSERT_* methods in non-void-returning functions (when // exceptions are disabled during compilation) void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_DataType dtype, const std::vector& dims, TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); - TF_SetAttrType(desc, "dtype", TF_INT32); + TF_SetAttrType(desc, "dtype", dtype); + if (!dims.empty()) { + TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); + } *op = TF_FinishOperation(desc, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_NE(*op, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name, + TF_DataType dtype, const std::vector& dims) { TF_Operation* op; - PlaceholderHelper(graph, s, name, &op); + PlaceholderHelper(graph, s, name, dtype, dims, &op); return op; } @@ -126,6 +143,12 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(FloatTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op, bool check) { @@ -404,19 +427,7 @@ std::vector GetFuncNames(const tensorflow::GraphDef& graph_def) { CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) { TF_SessionOptions* opts = TF_NewSessionOptions(); - tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); - flags->tf_xla_cpu_global_jit = use_XLA; - if (use_XLA) { - tensorflow::ConfigProto config; - config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); - std::string contents; - contents.resize(config.ByteSizeLong()); - config.SerializeToArray(&contents[0], contents.size()); - TF_SetConfig(opts, contents.data(), contents.size(), s); - } + TF_EnableXLACompilation(opts, use_XLA); session_ = TF_NewSession(graph, opts, s); TF_DeleteSessionOptions(opts); } diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 2a70177c724c569844a5d8ad42b99bed20209946..cd19cf8d624d9b914b61132f93d918b046cdbd30 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -44,8 +44,12 @@ TF_Tensor* Int32Tensor(int32_t v); TF_Tensor* DoubleTensor(double v); +TF_Tensor* FloatTensor(float v); + TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, - const char* name = "feed"); + const char* name = "feed", + TF_DataType dtype = TF_INT32, + const std::vector& dims = {}); TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name = "const"); @@ -56,6 +60,9 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index e55cb672e97e1403a3dd864c91c176426eb3f067..a2d96357ac8a55be7fe03bf58e33ff1733967dd1 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -27,6 +27,14 @@ tf_cuda_library( ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:execute_node", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -54,11 +62,17 @@ tf_cuda_library( ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) @@ -93,6 +107,7 @@ tf_cuda_library( "//conditions:default": [ "//tensorflow/c:c_api", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8e834eb99c13d1f26da9f0860897267efc2fd01c..393851d13c9b1ad0184f2778fa11afb271bd241b 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -31,8 +31,13 @@ limitations under the License. #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/execute_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -40,6 +45,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" @@ -65,6 +71,7 @@ string DeviceName(const tensorflow::Device* d) { #ifdef TENSORFLOW_EAGER_USE_XLA std::atomic_int_fast64_t func_id_generator(0); #endif // TENSORFLOW_EAGER_USE_XLA + } // namespace extern "C" { @@ -76,181 +83,144 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, TF_SetConfig(&options->session_options, proto, proto_len, status); } +void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, + unsigned char async) { + options->async = async; +} void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, + unsigned char async, + TF_Status* status) { + status->status = ctx->context.SetAsyncForThread(async); +} + void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { - TF_Graph* graph = TF_NewGraph(); - TF_Session* session = TF_NewSession(graph, &opts->session_options, status); - if (status->status.ok()) { - if (session->device_mgr == nullptr || session->devices.empty()) { - status->status = tensorflow::errors::InvalidArgument( - "Provided TF_SessionOptions are not compatible with eager execution " - "(perhaps the TF_SessionOptions alluded to session execution in a " - "remote address space?)"); - } - } + std::vector devices; + status->status = tensorflow::DeviceFactory::AddDevices( + opts->session_options.options, "/job:localhost/replica:0/task:0", + &devices); if (!status->status.ok()) { - TF_DeleteGraph(graph); return nullptr; } - - return new TFE_Context(*opts, session); + std::unique_ptr device_mgr( + new tensorflow::DeviceMgr(devices)); + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, std::move(device_mgr), r); } -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - status->status = tensorflow::Status::OK(); - { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); - } - TF_Graph* graph = ctx->session->graph; - TF_DeleteSession(ctx->session, status); - TF_DeleteGraph(graph); - ctx->rendezvous->Unref(); - delete ctx; -} +void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { - return TF_SessionListDevices(ctx->session, status); + TF_DeviceList* list = new TF_DeviceList; + ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + return list; } -void TFE_ContextClearCaches(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); -} +void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context.ClearCaches(); } void TFE_ContextSetThreadLocalDevicePlacementPolicy( TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - ctx->thread_local_policies[std::this_thread::get_id()] = policy; + ctx->context.SetThreadLocalDevicePlacementPolicy( + static_cast(policy)); } +// Note: this function looks up a thread local policy. So it should be called in +// the appropriate client thread. In particular, in async mode, it may not be +// safe to call this function from the async EagerExecutor threads. extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->policy_map_mu); - auto policy_map_it = - ctx->thread_local_policies.find(std::this_thread::get_id()); - if (policy_map_it != ctx->thread_local_policies.end()) { - return policy_map_it->second; - } - return ctx->policy; + return static_cast( + ctx->context.GetDevicePlacementPolicy()); +} + +void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.AsyncWait(); +} + +void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { + status->status = ctx->context.GetStatus(); +} + +void TFE_ContextAsyncClearError(TFE_Context* ctx) { + ctx->context.ClearAsyncError(); } TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); if (!status->status.ok()) return nullptr; - return new TFE_TensorHandle(tensor, nullptr); + return new TFE_TensorHandle(tensor, nullptr, nullptr); } -void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } +void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { + DCHECK(h); + if (h->handle) { + h->handle->Unref(); + } + delete h; +} TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast(h->t.dtype()); + return static_cast(h->handle->dtype); } -int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); } - -int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { - return h->t.dim_size(dim_index); +int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->Tensor(&t); + return t == nullptr ? 0 : t->dims(); } -const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { - // This might be a bit confusing as a tensor on CPU can sometimes return - // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0". - // TODO(ashankar): Figure out which one would be nicer. - return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str(); +int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, + TF_Status* status) { + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->Tensor(&t); + return t == nullptr ? 0 : t->dim_size(dim_index); } -TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { - if (!IsCPU(h->d)) { - TF_SetStatus(status, TF_UNIMPLEMENTED, - tensorflow::strings::StrCat( - "TFE_TensorHandle can be resolved iff it is on CPU (this " - "handle is on ", - h->d->name(), - "). Consider using TFE_TensorHandleCopyToDevice to get a " - "copy of the tensor on CPU") - .c_str()); - return nullptr; - } - return tensorflow::TF_TensorFromTensor(h->t, status); +const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + tensorflow::Device* d = nullptr; + status->status = h->handle->OpDevice(&d); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); } -TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status) { - tensorflow::Device* dstd = ctx->devices()[0]; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; +TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { + // TODO(agarwal): move this implementation inside TFE_TensorHandle. + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + tensorflow::TensorHandle* h_cpu = nullptr; + if (!IsCPU(d)) { + status->status = h->handle->CopyToDevice( + h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); + if (!status->status.ok()) { + return nullptr; + } + status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) { + h_cpu->Unref(); + return nullptr; + } } - - tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d; - bool is_same_device = - (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); - const bool dst_cpu = IsCPU(dstd); - const bool src_cpu = IsCPU(srcd); - // both_on_cpu can be true and yet is_same_device is false, if one of src/dst - // has device type XLA_CPU, and the other CPU. - const bool both_on_cpu = src_cpu && dst_cpu; - if (is_same_device || both_on_cpu) { - return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); - } - tensorflow::Tensor* src = &(h->t); - if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && - !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat("Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), - " to device ", DeviceName(dstd), ".") - .c_str()); - return nullptr; + TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); + if (h_cpu != nullptr) { + h_cpu->Unref(); } - tensorflow::AllocatorAttributes attr; - if (src->dtype() == tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); - if (src->shape().num_elements() == 0) { - return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd); - } - tensorflow::DeviceContext* src_device_context = nullptr; - if (!src_cpu) { - src_device_context = srcd->tensorflow_gpu_device_info()->default_context; - } - tensorflow::DeviceContext* dst_device_context = nullptr; - if (!dst_cpu) { - dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; - } - // TODO(ashankar): The Sync() call below may be more aggressive than - // necessary. It is based on knowledge of implementation details - that - // GPU devices are implemented using 3 streams - one for host->device copies, - // one for device->host copies and one for sending operations to the GPU. - // With that setup, Sync()ing across all 3 streams should be sufficient - // but more than necessary (since it waits for operations that might have - // nothing to do with this tensor to complete). - status->status = srcd->Sync(); - tensorflow::Notification n; - tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, - srcd, dstd, tensorflow::AllocatorAttributes(), - tensorflow::AllocatorAttributes(), src, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); - n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) - ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd) - : nullptr; + return retval; } +} // extern "C" + +extern "C" { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { @@ -259,8 +229,7 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, status->status = tensorflow::AttrTypeMapForOp(name, &types); if (status->status.ok()) return new TFE_Op(ctx, name, types); if (TF_GetCode(status) == TF_NOT_FOUND) { - tensorflow::mutex_lock l(ctx->functions_mu); - if (ctx->func_lib_def.Find(name) != nullptr) { + if (ctx->context.FindFunctionByName(name)) { status->status = tensorflow::Status::OK(); return new TFE_Op(ctx, name, nullptr); } @@ -273,16 +242,14 @@ void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - op->ctx->session->device_mgr->LookupDevice(device_name, &d); - if (!status->status.ok()) return; + status->status = op->ctx->context.FindDeviceByName(device_name, &d); } op->device = d; } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->devices()[0] : op->device; + (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; return device->name().c_str(); } @@ -295,17 +262,8 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - // Questionable heuristic ... - // - // Motivation: After an 'op' is placed on GPU because some of its earlier - // inputs are on GPU, we want to keep the 'op' there, even if some later - // inputs of it are not on GPU. - if (IsCPU(op->device) && !IsCPU(h->d)) { - op->device = h->d; - } - if (!status->status.ok()) return; - op->inputs.push_back(h->t); - op->input_devices.push_back(h->d); + h->handle->Ref(); + op->inputs.push_back(h->handle); op->attrs.NumInputs(op->inputs.size()); } @@ -467,14 +425,43 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, tensorflow::gtl::ArraySlice( funcs.get(), num_values)); } +} // extern "C" namespace { +// Initializes the step stats if needed. +void MaybeInitializeStepStats(tensorflow::StepStats* step_stats, + tensorflow::EagerContext* ctx) { + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->devices()->size()) { + int device_idx = step_stats->dev_stats_size(); + auto* dev_stats = step_stats->add_dev_stats(); + dev_stats->set_device(ctx->devices()->at(device_idx)->name()); + } +} + +int StepStatsDeviceIndex(tensorflow::StepStats* step_stats, + tensorflow::EagerContext* ctx, + tensorflow::Device* device) { + // Find the current device's index. + if (device == nullptr) { + device = ctx->HostCPU(); + } + for (int i = 0; i < ctx->devices()->size(); ++i) { + if (ctx->devices()->at(i) == device || + ctx->devices()->at(i)->name() == device->name()) { + return i; + } + } + // TODO(apassos) do not fall back to host CPU if device is unknown. + return 0; +} + tensorflow::Status ValidateInputTypeAndPlacement( - TFE_Context* ctx, tensorflow::Device* host_device, - tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel, - std::vector* copied_tensors) { + tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op, + const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) { + tensorflow::Device* host_device = ctx->HostCPU(); const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -483,20 +470,23 @@ tensorflow::Status ValidateInputTypeAndPlacement( for (int i = 0; i < op->inputs.size(); ++i) { const tensorflow::Device* expected_device = memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; + tensorflow::TensorHandle* handle = op->inputs[i]; + tensorflow::Device* handle_device = nullptr; + TF_RETURN_IF_ERROR(handle->Device(&handle_device)); const tensorflow::Device* actual_device = - op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; + handle_device == nullptr ? host_device : handle_device; if (expected_device != actual_device) { - switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { - case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: + switch (ctx->GetDevicePlacementPolicy()) { + case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. - if (op->inputs[i].dtype() == tensorflow::DT_INT32) { + if (handle->dtype == tensorflow::DT_INT32) { // Note: enabling silent copies of int32 tensors to match behavior // of graph mode. break; } TF_FALLTHROUGH_INTENDED; - case TFE_DEVICE_PLACEMENT_EXPLICIT: + case tensorflow::DEVICE_PLACEMENT_EXPLICIT: return tensorflow::errors::InvalidArgument( "Tensors on conflicting devices:" " cannot compute ", @@ -504,11 +494,13 @@ tensorflow::Status ValidateInputTypeAndPlacement( expected_device->name(), " but is actually on ", actual_device->name(), " (operation running on ", op_device->name(), ")", - " Tensors can be copied explicitly using .gpu() or .cpu()," - " or transparently copied by using tfe.enable_eager_execution(" - "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" + " Tensors can be copied explicitly using .gpu() or .cpu() " + "methods," + " or transparently copied by using tf.enable_eager_execution(" + "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors " + "between devices" " may slow down your model"); - case TFE_DEVICE_PLACEMENT_WARN: + case tensorflow::DEVICE_PLACEMENT_WARN: LOG(WARNING) << "before computing " << op->name << " input #" << i << " was expected to be on " << expected_device->name() << " but is actually on " << actual_device->name() @@ -516,40 +508,77 @@ tensorflow::Status ValidateInputTypeAndPlacement( << "). This triggers a copy which can be a performance " "bottleneck."; break; - case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. + case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing. break; } // We are only here if the policy is warn or silent copies, so we should // trigger a copy. - TFE_TensorHandle original{op->inputs[i], op->input_devices[i]}; - TF_Status* s = TF_NewStatus(); - TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - &original, ctx, expected_device->name().c_str(), s); - if (!s->status.ok()) { - tensorflow::Status status = s->status; - delete s; + auto pre_time = tensorflow::Env::Default()->NowMicros(); + tensorflow::TensorHandle* copied_tensor = nullptr; + tensorflow::Status status = tensorflow::EagerCopyToDevice( + handle, ctx, expected_device->name().c_str(), &copied_tensor); + if (run_metadata != nullptr) { + auto* step_stats = run_metadata->mutable_step_stats(); + MaybeInitializeStepStats(step_stats, ctx); + // Record the sending on the source device for now. + int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device); + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + auto* node_stats = dev_stats->add_node_stats(); + node_stats->set_node_name("_Send"); + node_stats->set_all_start_micros(pre_time); + node_stats->set_op_end_rel_micros( + tensorflow::Env::Default()->NowMicros() - pre_time); + } + if (!status.ok()) { + if (copied_tensor != nullptr) copied_tensor->Unref(); return tensorflow::errors::Internal( "Failed copying input tensor from ", actual_device->name(), " to ", expected_device->name(), " in order to run ", op->name, ": ", status.error_message()); } - op->inputs[i] = copied_tensor->t; - copied_tensors->push_back(copied_tensor); - op->input_devices[i] = copied_tensor->d; - delete s; + handle->Unref(); + handle = copied_tensor; + op->inputs[i] = copied_tensor; } - if (op->inputs[i].dtype() != kernel->input_type(i)) { + if (handle->dtype != kernel->input_type(i)) { return tensorflow::errors::InvalidArgument( "cannot compute ", op->name, " as input #", i, " was expected to be a ", tensorflow::DataTypeString(kernel->input_type(i)), - " tensor but is a ", - tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor"); + " tensor but is a ", tensorflow::DataTypeString(handle->dtype), + " tensor"); } } return tensorflow::Status::OK(); } +tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, + TFE_Context* ctx, TF_Status* status) { + tensorflow::DeviceSet ds; + for (tensorflow::Device* d : *ctx->context.devices()) { + ds.AddDevice(d); + } + tensorflow::DeviceTypeVector final_devices; + status->status = tensorflow::SupportedDeviceTypesForNode( + ds.PrioritizedDeviceTypeList(), ndef, &final_devices); + if (!status->status.ok()) { + return nullptr; + } + if (final_devices.empty()) { + status->status = tensorflow::errors::Internal( + "Could not find valid device for node ", ndef.DebugString()); + return nullptr; + } + for (tensorflow::Device* d : *ctx->context.devices()) { + if (d->device_type() == final_devices[0].type_string()) { + return d; + } + } + status->status = tensorflow::errors::Unknown( + "Could not find a device for node ", ndef.DebugString()); + return nullptr; +} + #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a // primitive op (e.g. matmul). @@ -577,8 +606,7 @@ const tensorflow::FunctionDef* OpToFunction( TFE_Context* ctx = op->ctx; const tensorflow::OpRegistrationData* op_data; { - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + status->status = ctx->context.FindFunctionOpData(op->name, &op_data); if (!status->status.ok()) { return nullptr; } @@ -615,7 +643,7 @@ const tensorflow::FunctionDef* OpToFunction( (*op_input_to_func_input)[i] = const_index; func_input_arg = signature->mutable_input_arg(const_index++); const_input_types->push_back( - static_cast(op->inputs[i].dtype())); + static_cast(op->inputs[i]->dtype)); } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { VLOG(1) << "For resource input, mapping op input " << i << " to func input " << resource_index; @@ -627,11 +655,11 @@ const tensorflow::FunctionDef* OpToFunction( (*op_input_to_func_input)[i] = arg_index; func_input_arg = signature->mutable_input_arg(arg_index++); arg_input_types->push_back( - static_cast(op->inputs[i].dtype())); + static_cast(op->inputs[i]->dtype)); } func_input_arg->set_name(op_input_arg.name()); - func_input_arg->set_type(op->inputs[i].dtype()); + func_input_arg->set_type(op->inputs[i]->dtype); } VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); @@ -674,10 +702,9 @@ const tensorflow::FunctionDef* OpToFunction( } VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(fdef); + status->status = ctx->context.AddFunctionDef(fdef); if (!status->status.ok()) return nullptr; - const auto ret = ctx->func_lib_def.Find(signature->name()); + const auto ret = ctx->context.FindFunctionDef(signature->name()); DCHECK(ret != nullptr); return ret; } @@ -695,10 +722,7 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { } const tensorflow::FunctionDef* fdef; - { - tensorflow::tf_shared_lock l(op->ctx->functions_mu); - fdef = op->ctx->func_lib_def.Find(op->name); - } + { fdef = op->ctx->context.FindFunctionDef(op->name); } std::vector const_input_types; std::vector arg_input_types; tensorflow::gtl::FlatMap op_input_to_func_input; @@ -724,21 +748,16 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { // Since input param reordering may have occurred between `op` and `launch_op` // via `op_input_to_func_input`, adjust the actual inputs accordingly. launch_op->inputs = op->inputs; - launch_op->input_devices = op->input_devices; + for (tensorflow::TensorHandle* h : launch_op->inputs) { + h->Ref(); + } if (!op_input_to_func_input.empty()) { DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); - if (!op->input_devices.empty()) { - DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); - } for (int i = 0; i < op_input_to_func_input.size(); ++i) { VLOG(1) << "mapping op input " << i << " to func input " << op_input_to_func_input[i]; launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; - if (!op->input_devices.empty()) { - launch_op->input_devices[op_input_to_func_input[i]] = - op->input_devices[i]; - } } } launch_op->attrs.NumInputs(op->inputs.size()); @@ -771,15 +790,18 @@ std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { return launch_op; } #endif // TENSORFLOW_EAGER_USE_XLA + } // namespace +extern "C" { + void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { TFE_Context* ctx = op->ctx; - // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU - tensorflow::Device* device = - (op->device == nullptr) ? ctx->devices()[0] : op->device; - + status->status = ctx->context.GetStatus(); + if (!status->status.ok()) { + return; + } #ifdef TENSORFLOW_EAGER_USE_XLA std::unique_ptr xla_launch_op; if (op->use_xla && op->name != "_XlaLaunch") { @@ -790,45 +812,102 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, op = xla_launch_op.get(); } #endif // TENSORFLOW_EAGER_USE_XLA - - std::vector outputs(1); - const tensorflow::MemoryTypeVector* output_memory_types = nullptr; - tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); - tensorflow::KernelAndDevice* kernel; - { - tensorflow::tf_shared_lock l(ctx->cache_mu); - kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + // Ensure all resource-touching ops run in the device the resource is, + // regardless of anything else that has been specified. This is identical to + // the graph mode behavior. + for (int i = 0; i < op->inputs.size(); ++i) { + tensorflow::Device* input_op_device = nullptr; + status->status = op->inputs[i]->OpDevice(&input_op_device); + if (!status->status.ok()) return; + VLOG(2) << "for op " << op->name << " input " << i << " " + << tensorflow::DataTypeString(op->inputs[i]->dtype) << " " + << (input_op_device == nullptr ? "cpu" : input_op_device->name()) + << " " << (op->device == nullptr ? "cpu" : op->device->name()); + if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && + (input_op_device != op->device || input_op_device == nullptr)) { + tensorflow::Device* d = + input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; + VLOG(1) << "Changing device of operation " << op->name << " to " + << d->name() << " because input #" << i + << " is a resource in this device."; + op->device = d; + } } + tensorflow::Device* device = op->device; + + tensorflow::Fprint128 cache_key = + op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); + tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); + if (device == nullptr) { + device = SelectDevice(ndef, ctx, status); + if (!status->status.ok()) { + return; + } + } + CHECK(device != nullptr); + if (ctx->context.LogDevicePlacement()) { + LOG(INFO) << "Executing op " << ndef.op() << " in device " + << device->name(); + } + kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def // will be accessed, so grab on to the lock. - // See WARNING comment below - would be nice to rework to avoid this - // subtlety. - tensorflow::tf_shared_lock l(ctx->functions_mu); - status->status = - tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel); + // See WARNING comment in Execute (before kernel->Run) - would be nice to + // rework to avoid this subtlety. + tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); + status->status = tensorflow::KernelAndDevice::Init( + ndef, ctx->context.func_lib(device), kernel); if (!status->status.ok()) { delete kernel; return; } - tensorflow::mutex_lock ml(ctx->cache_mu); - tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); - } - std::vector copied_tensors; - status->status = ValidateInputTypeAndPlacement( - ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors); - output_memory_types = &kernel->kernel()->output_memory_types(); - if (!status->status.ok()) { - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); + // Update output_dtypes inside `kernel`. + const tensorflow::OpDef* op_def = nullptr; + const tensorflow::FunctionDef* function_def = + ctx->context.FuncLibDef()->Find(ndef.op()); + if (function_def != nullptr) { + op_def = &(function_def->signature()); + } + if (op_def == nullptr) { + status->status = OpDefForOp(ndef.op().c_str(), &op_def); + if (!status->status.ok()) { + return; + } } + tensorflow::DataTypeVector input_dtypes; + status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, + kernel->mutable_output_dtypes()); + if (!status->status.ok()) { + return; + } + ctx->context.AddKernelToCache(cache_key, kernel); + } + const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); + const int output_dtypes_size = output_dtypes.size(); + if (output_dtypes_size > *num_retvals) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + tensorflow::strings::StrCat("Expecting ", output_dtypes.size(), + " outputs, but *num_retvals is ", + *num_retvals) + .c_str()); return; } + *num_retvals = output_dtypes_size; + if (device == nullptr) { + // TODO(apassos) debug how the assignment below might return a different + // device from the one requested above. + device = kernel->device(); + } + status->status = ValidateInputTypeAndPlacement( + &ctx->context, device, op, kernel->kernel(), + ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto() + : nullptr); + if (!status->status.ok()) return; std::unique_ptr maybe_stats; - if (ctx->should_store_metadata.load()) { + if (ctx->context.ShouldStoreMetadata()) { maybe_stats.reset(new tensorflow::NodeExecStats); maybe_stats->set_node_name(op->name); maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); @@ -836,56 +915,50 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); // TODO(apassos) track referenced tensors } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, - // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. - status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); - for (auto* t : copied_tensors) { - TFE_DeleteTensorHandle(t); - } - if (!status->status.ok()) return; - if (maybe_stats != nullptr) { - maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - - maybe_stats->all_start_micros()); - tensorflow::mutex_lock ml(ctx->metadata_mu); - if (ctx->should_store_metadata.load()) { - auto* step_stats = ctx->run_metadata.mutable_step_stats(); - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices().size()) { - step_stats->add_dev_stats(); - } - // Find the current device's index. - int device_idx = 0; - for (int i = 0; i < ctx->devices().size(); ++i) { - if (ctx->devices()[i] == device) { - device_idx = i; - break; - } - } - // Populate the device stats for this device. - auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - dev_stats->set_device(device->name()); - *dev_stats->add_node_stats() = *maybe_stats; + if (ctx->context.Async()) { + // Note that for async mode, execution order will make sure that all + // input handles are ready before executing them. + // TODO(agarwal): Consider executing "cheap" kernels inline for performance. + tensorflow::gtl::InlinedVector handle_retvals( + *num_retvals); + tensorflow::uint64 id = op->ctx->context.NextId(); + for (int i = 0; i < *num_retvals; ++i) { + tensorflow::TensorHandle* h = + new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context); + retvals[i] = new TFE_TensorHandle(h); + handle_retvals[i] = h; } - } - *num_retvals = std::min(*num_retvals, outputs.size()); - for (int i = 0; i < *num_retvals; ++i) { - tensorflow::Device* d = IsCPU(device) ? nullptr : device; - if (d != nullptr && output_memory_types != nullptr && - (*output_memory_types)[i] == tensorflow::HOST_MEMORY) { - d = nullptr; + tensorflow::EagerNode* node = new tensorflow::ExecuteNode( + id, &op->ctx->context, op->device, op->inputs, kernel, + maybe_stats.release(), output_dtypes, handle_retvals); + ctx->context.ExecutorAdd(node); + } else { + // Execute checks if retvals[i] is nullptr or not to figure if it needs to + // allocate it. + tensorflow::gtl::InlinedVector handle_retvals( + *num_retvals); + status->status = tensorflow::EagerExecute( + &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(), + handle_retvals.data(), *num_retvals); + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } - retvals[i] = new TFE_TensorHandle(outputs[i], d); } } +TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status) { + tensorflow::TensorHandle* handle; + status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + device_name, &handle); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); + } + return nullptr; +} + void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status) { @@ -895,46 +968,126 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function_def); + status->status = ctx->context.AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - tensorflow::mutex_lock l(ctx->functions_mu); - status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); + status->status = ctx->context.AddFunctionDef(function->fdef); +} + +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->context.SetShouldStoreMetadata(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + ctx->context.SetShouldStoreMetadata(false); } } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { - return new TFE_TensorHandle(t, nullptr); + return new TFE_TensorHandle(t, nullptr, nullptr); } const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( TFE_TensorHandle* h, TF_Status* status) { - if (h->d != nullptr) { + tensorflow::Device* d = nullptr; + tensorflow::Device* op_device = nullptr; + const tensorflow::Tensor* t = nullptr; + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); + if (!status->status.ok()) return nullptr; + if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( "TFE_TensorHandle is placed in device (not host) memory. Cannot return " "a tensorflow::Tensor"); return nullptr; } - return &h->t; + return t; } -void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { - ctx->should_store_metadata.store(true); +void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + TFE_ContextAsyncWait(ctx, status); + if (!status->status.ok()) return; + tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); + status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); + ctx->context.RunMetadataProto()->Clear(); } -void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - ctx->should_store_metadata.store(false); - ctx->run_metadata.Clear(); +namespace { +TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, + TF_Status* status) { + TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); + for (const auto& attr : func.attr()) { + if (TF_GetCode(status) != TF_OK) return nullptr; + SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + return func_op; } +} // namespace -void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, - TF_Status* status) { - tensorflow::mutex_lock ml(ctx->metadata_mu); - status->status = MessageToBuffer(ctx->run_metadata, buf); - ctx->run_metadata.Clear(); +namespace tensorflow { +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status) { + switch (default_value.value_case()) { + case tensorflow::AttrValue::kS: + TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + break; + case tensorflow::AttrValue::kI: + TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); + break; + case tensorflow::AttrValue::kF: + TFE_OpSetAttrFloat(op, attr_name, default_value.f()); + break; + case tensorflow::AttrValue::kB: + TFE_OpSetAttrBool(op, attr_name, default_value.b()); + break; + case tensorflow::AttrValue::kType: + TFE_OpSetAttrType(op, attr_name, + static_cast(default_value.type())); + break; + case tensorflow::AttrValue::kShape: { + const auto& tensor_shape = default_value.shape(); + if (tensor_shape.unknown_rank()) { + TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); + } else { + const auto num_dims = tensor_shape.dim_size(); + std::unique_ptr dims(new int64_t[num_dims]); + for (int i = 0; i < num_dims; ++i) { + dims[i] = tensor_shape.dim(i).size(); + } + TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); + } + } break; + case tensorflow::AttrValue::kFunc: { + const auto func_op = GetFunc(ctx, default_value.func(), status); + if (TF_GetCode(status) != TF_OK) return; + // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList + // require TFE_Op* and just convert it internally a NameAttrValue, so + // consider adding an overload to the C API to make this case easier. + TFE_OpSetAttrFunction(op, attr_name, func_op); + } break; + case tensorflow::AttrValue::kList: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kTensor: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::kPlaceholder: + TF_FALLTHROUGH_INTENDED; + case tensorflow::AttrValue::VALUE_NOT_SET: + TF_SetStatus( + status, TF_UNIMPLEMENTED, + tensorflow::strings::StrCat("Unable to get setfor default value: ", + default_value.DebugString()) + .data()); + } +} +} // namespace tensorflow + +TFE_Op::~TFE_Op() { + for (tensorflow::TensorHandle* h : inputs) { + h->Unref(); + } } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 7a321b54da343fd2b8912187bc620c1e7456db0c..c06ce84a8c578aa60dd626c24bd58098b78ae750 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -30,7 +30,7 @@ limitations under the License. #ifdef SWIG #define TF_CAPI_EXPORT #else -#if defined(COMPILER_MSVC) +#if defined(_WIN32) #ifdef TF_COMPILE_LIBRARY #define TF_CAPI_EXPORT __declspec(dllexport) #else @@ -38,7 +38,7 @@ limitations under the License. #endif // TF_COMPILE_LIBRARY #else #define TF_CAPI_EXPORT __attribute__((visibility("default"))) -#endif // COMPILER_MSVC +#endif // _WIN32 #endif // SWIG #ifdef __cplusplus @@ -65,14 +65,19 @@ typedef enum TFE_ContextDevicePlacementPolicy { TFE_DEVICE_PLACEMENT_EXPLICIT = 0, // Copy the tensor to the right device but log a warning. TFE_DEVICE_PLACEMENT_WARN = 1, - // Silently copy the tensor, which has a performance cost since the - // operation will be blocked till the copy completes. + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default placement + // policy. TFE_DEVICE_PLACEMENT_SILENT = 2, - // Default placement policy which silently copies int32 tensors but not other - // dtypes. + // Placement policy which silently copies int32 tensors but not other dtypes. TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; +// Sets the default execution mode (sync/async). Note that this can be +// overridden per thread using TFE_ContextSetAsyncForThread. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, + unsigned char async); + TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -108,6 +113,30 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(TFE_Context*); +// Overrides the execution mode (sync/async) for the current thread. +TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, + unsigned char async, + TF_Status* status); + +// Causes the calling thread to block till all ops dispatched in async mode +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. +TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*, + TF_Status* status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*); + // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, @@ -117,13 +146,25 @@ typedef struct TFE_TensorHandle TFE_TensorHandle; TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status); +// Indicates that the caller will not be using `h` any more. TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); -TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, - int dim_index); + int dim_index, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( - TFE_TensorHandle* h); + TFE_TensorHandle* h, TF_Status* status); + +// This function will block till the operation that produces `h` has +// completed. The memory returned might alias the internal memory used by +// TensorFlow. Hence, callers should not mutate this memory (for example by +// modifying the memory region pointed to by TF_TensorData() on the returned +// TF_Tensor). TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); @@ -133,6 +174,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). +// If async execution is enabled, the copy may be enqueued and the call will +// return "non-ready" handle. Else, this function returns after the copy has +// been done. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); @@ -153,6 +197,7 @@ typedef struct TFE_Op TFE_Op; TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status); + TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, @@ -238,13 +283,21 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, int num_values); // Execute the operation defined by 'op' and return handles to computed -// tensors in 'retvals'. +// tensors in `retvals`. +// +// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. // -// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* -// and '*num_retvals' should be set to the size of this array. +// If async execution is enabled, the call may simply enqueue the execution +// and return "non-ready" handles in `retvals`. Note that any handles contained +// in 'op' should not be mutated till the kernel execution actually finishes. // -// On return, 'num_retvals' will be set to the actual number of outputs -// returned by the operation. +// For sync execution, if any of the inputs to `op` are not ready, this call +// will block till they become ready and then return when the kernel execution +// is done. +// TODO(agarwal): change num_retvals to int from int*. TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status); @@ -270,6 +323,8 @@ TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); // Populates the passed-in buffer with a serialized RunMetadata protocol buffer // containing any run metadata information accumulated so far and clears this // information. +// If async mode is enabled, this call blocks till all currently pending ops are +// done. TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 7b9f1db02ed9c53a280c7bd1284165cac4fb6353..05dc64f521735f944559392f470a37590e93f17c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include #include +#include #include #include #include @@ -28,82 +30,55 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" + struct TFE_ContextOptions { TF_SessionOptions session_options; - TFE_ContextDevicePlacementPolicy policy{ - TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; + // true if async execution is enabled. + bool async = false; + TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; }; struct TFE_Context { - explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) - : policy(opts.policy), - session(s), - rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), - pflr(new tensorflow::ProcessFunctionLibraryRuntime( - session->device_mgr, opts.session_options.options.env, - TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} - - const TFE_ContextDevicePlacementPolicy policy; - - // Note: we cannot use C++11 thread_local here as there is no concept of a - // thread-local-object-local variable in C++11. - tensorflow::mutex policy_map_mu; - std::unordered_map - thread_local_policies GUARDED_BY(policy_map_mu); - - // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* const session; - tensorflow::Rendezvous* const rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - const std::unique_ptr pflr; - - tensorflow::mutex cache_mu; - std::unordered_map - kernel_cache GUARDED_BY(cache_mu); - - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { - return pflr->GetFLR(d->name()); - } - - const std::vector& devices() { return session->devices; } - - // Whether we should compute RunMetadata. - std::atomic should_store_metadata{false}; - tensorflow::mutex metadata_mu; - tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); + explicit TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, + bool async, + std::unique_ptr device_mgr, + tensorflow::Rendezvous* rendezvous) + : context(opts, + static_cast( + default_policy), + async, std::move(device_mgr), rendezvous) {} + + tensorflow::EagerContext context; }; struct TFE_TensorHandle { - TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) - : t(t), d(d) {} - - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, + tensorflow::Device* op_device) + : handle(new tensorflow::TensorHandle(t, d, op_device, nullptr)) {} + + TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, + tensorflow::EagerContext* ctx) + : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} + + TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} + + tensorflow::TensorHandle* handle; }; struct TFE_Op { @@ -112,16 +87,24 @@ struct TFE_Op { TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + ~TFE_Op(); + bool const is_function() const { return attr_types == nullptr; } TFE_Context* ctx; // Must outlive the TFE_Op. const tensorflow::string name; tensorflow::AttrBuilder attrs; const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; + tensorflow::gtl::InlinedVector inputs; tensorflow::Device* device; bool use_xla = false; }; +namespace tensorflow { +// Set an AttrValue on the op. Doesn't handle the list types. +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status); +} // namespace tensorflow + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4a3ecbc0abb16296a84c0d2184dc3fc9f7f3ebb4..701175e4943d1d23532fe595319f67711316ed4d 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -29,6 +29,20 @@ using tensorflow::string; namespace { +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestMatrixTensorHandle() { int64_t dims[] = {2, 2}; float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -43,6 +57,20 @@ TFE_TensorHandle* TestMatrixTensorHandle() { return th; } +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { TF_Status* status = TF_NewStatus(); @@ -139,10 +167,12 @@ void BM_InitOp(int iters) { } BENCHMARK(BM_InitOp); -void BM_Execute(int iters) { +void BM_Execute(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteAsync" : "Execute"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -156,6 +186,9 @@ void BM_Execute(int iters) { TFE_Execute(matmul, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); @@ -163,7 +196,7 @@ void BM_Execute(int iters) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_Execute); +BENCHMARK(BM_Execute)->Arg(0)->Arg(1); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); @@ -205,10 +238,11 @@ TEST(CAPI, TensorHandle) { TFE_DeleteTensorHandle(h); } -TEST(CAPI, TensorHandleCopyBetweenDevices) { +void TensorHandleCopyBetweenDevices(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -274,10 +308,56 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { +TEST(CAPI, TensorHandleCopyBetweenDevices) { + TensorHandleCopyBetweenDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesAsync) { + TensorHandleCopyBetweenDevices(true); +} + +void TensorHandleCopyBetweenDevicesError(bool async) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + const char* kErrorDevice = "NoSuchDevice:0"; + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kErrorDevice, status.get()); + EXPECT_NE(TF_OK, TF_GetCode(status.get())); + const char* msg = "NoSuchDevice:0 unknown device"; + EXPECT_TRUE(strstr(TF_Message(status.get()), msg) != nullptr) + << TF_Message(status.get()); + TF_SetStatus(status.get(), TF_OK, ""); + const char* kCPUDevice = "CPU:0"; + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())); + TFE_DeleteTensorHandle(hcopy); + TFE_DeleteTensorHandle(hcpu); + if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); + TFE_DeleteContext(ctx, status.get()); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesError) { + TensorHandleCopyBetweenDevicesError(false); +} + +TEST(CAPI, TensorHandleCopyBetweenDevicesErrorAsync) { + TensorHandleCopyBetweenDevicesError(true); +} + +void TensorHandleCopyBetweenTwoGPUDevices(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -332,11 +412,20 @@ TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopy) { +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { + TensorHandleCopyBetweenTwoGPUDevices(false); +} + +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevicesAsync) { + TensorHandleCopyBetweenTwoGPUDevices(true); +} + +void TensorHandleSilentCopy(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -366,14 +455,20 @@ TEST(CAPI, TensorHandleSilentCopy) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopyLocal) { +TEST(CAPI, TensorHandleSilentCopy) { TensorHandleSilentCopy(false); } +TEST(CAPI, TensorHandleSilentCopyAsync) { TensorHandleSilentCopy(true); } + +void TensorHandleSilentCopyLocal(bool async) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status.get()); @@ -407,11 +502,17 @@ TEST(CAPI, TensorHandleSilentCopyLocal) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); + TFE_ContextAsyncWait(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } +TEST(CAPI, TensorHandleSilentCopyLocalAsync) { + TensorHandleSilentCopyLocal(true); +} -TEST(CAPI, SetAndGetOpDevices) { +void SetAndGetOpDevices(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -442,27 +543,28 @@ TEST(CAPI, SetAndGetOpDevices) { TF_DeleteStatus(status); } -TEST(CAPI, Execute_MatMul_CPU) { +void Execute_MatMul_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; + int num_retvals = 2; TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(1, num_retvals); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -474,7 +576,107 @@ TEST(CAPI, Execute_MatMul_CPU) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_CPU) { Execute_MatMul_CPU(false); } +TEST(CAPI, Execute_MatMul_CPUAsync) { Execute_MatMul_CPU(true); } + +void Execute_MatMul_CPU_Runtime_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", + status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Op* matmul2 = MatMulOp(ctx, m1, m1); + TFE_OpSetDevice(matmul2, "/job:localhost/replica:0/task:0/device:CPU:0", + status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + TFE_DeleteOp(matmul); + if (!async) { + EXPECT_NE(TF_OK, TF_GetCode(status)); + } else { + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + EXPECT_EQ(nullptr, t); + const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) + << TF_Message(status); + // Since error is not cleared, the following copy with correct device will + // still fail. + TF_SetStatus(status, TF_OK, ""); + TFE_DeleteTensorHandle(retvals[0]); + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_ContextAsyncClearError(ctx); + TFE_ContextAsyncWait(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + } + // Following works in async mode since TFE_ContextAsyncClearError was called. + TF_SetStatus(status, TF_OK, ""); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + retvals[0] = nullptr; + TFE_Execute(matmul2, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteTensor(t); + TFE_DeleteOp(matmul2); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_Error) { + Execute_MatMul_CPU_Runtime_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Runtime_ErrorAsync) { + Execute_MatMul_CPU_Runtime_Error(true); +} + +void Execute_MatMul_CPU_Type_Error(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m1 = TestMatrixTensorHandle(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m1, m2); + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_NE(TF_OK, TF_GetCode(status)); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m1); + TFE_DeleteTensorHandle(m2); + if (retvals[0] != nullptr) { + TFE_DeleteTensorHandle(retvals[0]); + } + TFE_DeleteContext(ctx, status); + TF_DeleteStatus(status); +} +TEST(CAPI, Execute_MatMul_CPU_Type_Error) { + Execute_MatMul_CPU_Type_Error(false); +} +TEST(CAPI, Execute_MatMul_CPU_Type_ErrorAsync) { + Execute_MatMul_CPU_Type_Error(true); +} TEST(CAPI, Execute_Min_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -485,33 +687,34 @@ TEST(CAPI, Execute_Min_CPU) { TFE_TensorHandle* input = TestMatrixTensorHandle(); TFE_TensorHandle* axis = TestAxisTensorHandle(); TFE_Op* minOp = MinOp(ctx, input, axis); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); - TFE_DeleteTensorHandle(retvals[0]); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); float output[2] = {0}; EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } #ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Execute_MatMul_XLA_CPU) { +void Execute_MatMul_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -521,15 +724,14 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { TFE_OpSetXLACompilation(matmul, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); // Running a primitive TF operator via XLA is not yet supported. ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); EXPECT_EQ(1, num_retvals); @@ -545,13 +747,16 @@ TEST(CAPI, Execute_MatMul_XLA_CPU) { EXPECT_EQ(10, product[1]); EXPECT_EQ(15, product[2]); EXPECT_EQ(22, product[3]); - + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_MatMul_XLA_CPU) { Execute_MatMul_XLA_CPU(false); } +TEST(CAPI, Execute_MatMul_XLA_CPUAsync) { Execute_MatMul_XLA_CPU(true); } -TEST(CAPI, Execute_Min_XLA_CPU) { +void Execute_Min_XLA_CPU(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -562,14 +767,13 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TFE_OpSetXLACompilation(minOp, true); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(minOp, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(minOp); TFE_DeleteTensorHandle(input); TFE_DeleteTensorHandle(axis); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); @@ -582,13 +786,17 @@ TEST(CAPI, Execute_Min_XLA_CPU) { TF_DeleteTensor(t); EXPECT_EQ(1, output[0]); EXPECT_EQ(3, output[1]); + TFE_DeleteContext(ctx, status); TF_DeleteStatus(status); } +TEST(CAPI, Execute_Min_XLA_CPU) { Execute_Min_XLA_CPU(false); } +TEST(CAPI, Execute_Min_XLA_CPUAsync) { Execute_Min_XLA_CPU(true); } #endif // TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, ExecuteWithTracing) { +void ExecuteWithTracing(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); TFE_ContextEnableRunMetadata(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -596,8 +804,8 @@ TEST(CAPI, ExecuteWithTracing) { TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); - TFE_TensorHandle* retvals[2] = {nullptr}; - int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_TensorHandle* retvals[1] = {nullptr}; + int num_retvals = 1; TFE_Execute(matmul, &retvals[0], &num_retvals, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); @@ -609,12 +817,12 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_TRUE( rm.ParseFromString({reinterpret_cast(b->data), b->length})); TF_DeleteBuffer(b); - TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float product[4] = {0}; EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); @@ -626,6 +834,8 @@ TEST(CAPI, ExecuteWithTracing) { EXPECT_EQ(22, product[3]); TF_DeleteStatus(status); } +TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); } +TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); } TEST(CAPI, Function_ident_CPU) { // First create a simple identity function. @@ -657,32 +867,37 @@ TEST(CAPI, Function_ident_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -719,35 +934,40 @@ TEST(CAPI, Function_ident_XLA_CPU) { ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteFunction(fn); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); + for (bool async : {false, true, false}) { + TFE_ContextSetAsyncForThread(ctx, static_cast(async), + status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - // Now run it via XLA. - TFE_OpSetXLACompilation(op, true); + // Now run it via XLA. + TFE_OpSetXLACompilation(op, true); - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } TFE_DeleteContext(ctx, status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); TF_DeleteStatus(status); @@ -788,9 +1008,10 @@ string MatMulFunction() { return def.SerializeAsString(); } -TEST(CAPI, FunctionDefAndExecute) { +void FunctionDefAndExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -827,11 +1048,16 @@ TEST(CAPI, FunctionDefAndExecute) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } +TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); } +TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); } -void BM_ExecuteFunction(int iters) { +void BM_ExecuteFunction(int iters, int async) { tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? "ExecuteFunctionAsync" + : "ExecuteFunction"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -853,6 +1079,9 @@ void BM_ExecuteFunction(int iters) { TFE_Execute(matmul, &retval[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } + if (async) { + TFE_ContextAsyncWait(ctx, status); + } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); @@ -860,7 +1089,7 @@ void BM_ExecuteFunction(int iters) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); } -BENCHMARK(BM_ExecuteFunction); +BENCHMARK(BM_ExecuteFunction)->Arg(0)->Arg(1); TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, TF_Status* status) { @@ -932,7 +1161,8 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle)); - EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle)); + EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status)); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float value = 0.0f; TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -974,7 +1204,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(1, num_retvals); CHECK(h); CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); - CHECK_EQ(0, TFE_TensorHandleNumDims(h)); + CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; } tensorflow::testing::StopTiming(); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index f77a937f1ffc2d146224cb3191a5ca127daefc22..e6c51ab17a867a0697f15d7683d8ca52c062035d 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -41,17 +42,26 @@ const uint32 kIsList = 1U << 31; } // namespace +Status OpDefForOp(const char* op_name, const OpDef** op_def) { + const OpRegistrationData* op_reg_data = nullptr; + Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (s.ok()) { + *op_def = &op_reg_data->op_def; + } + return s; +} + Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { mutex_lock l(g_op_name_to_attr_type_map_lock); *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); if (*out != nullptr) return Status::OK(); - const OpRegistrationData* op_reg_data = nullptr; - Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + const OpDef* op_def = nullptr; + Status s = OpDefForOp(op_name, &op_def); if (!s.ok()) return s; std::unique_ptr m(new AttrTypeMap); // TODO(agarwal): Avoid having to create this "registry" at runtime, // perhaps can be done at op registration time? - for (const auto& attr : op_reg_data->op_def.attr()) { + for (const auto& attr : op_def->attr()) { string type = attr.type(); const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0); if (is_list) { @@ -86,22 +96,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, - TF_AttrType* out, unsigned char* is_list) { - auto* t = gtl::FindOrNull(m, attr_name); - if (t == nullptr) { - return errors::InvalidArgument("Attribute '", attr_name, - "' does not exist for this operation"); - } - *out = static_cast(*t & ~kIsList); - if (*t & kIsList) { - *is_list = 1; - } else { - *is_list = 0; - } - return Status::OK(); -} - #define DEFINE_SET_ATTR(value_type, value_field) \ template <> \ AttrBuilder& AttrBuilder::Set(StringPiece attr_name, value_type&& value) { \ @@ -159,6 +153,22 @@ const NodeDef& AttrBuilder::BuildNodeDef() { return *node_def_; } +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list) { + auto* t = gtl::FindOrNull(m, attr_name); + if (t == nullptr) { + return errors::InvalidArgument("Attribute '", attr_name, + "' does not exist for this operation"); + } + *out = static_cast(*t & ~kIsList); + if (*t & kIsList) { + *is_list = 1; + } else { + *is_list = 0; + } + return Status::OK(); +} + namespace { inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, const tensorflow::Fprint128& b) { @@ -174,8 +184,7 @@ void CombineUnordered(const tensorflow::Fprint128& a, inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, const tensorflow::Fprint128& b) { - // TODO(agarwal): avoid ToString(). - tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); + tensorflow::Fprint128 a = tensorflow::Fingerprint128(s); return FingerprintCat128(a, b); } @@ -203,10 +212,8 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const { if (node_def_finalized_) return f; } for (const auto& p : string_attrs_) { - // TODO(agarwal): avoid ToString(). - CombineUnordered(CacheKeyHelper(p.first, tensorflow::Fingerprint128( - p.second.ToString())), - &f); + CombineUnordered( + CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f); } for (const auto& p : int_attrs_) { CombineUnordered(CacheKeyHelper(p.first, static_cast(p.second)), @@ -236,93 +243,4 @@ void AttrBuilder::MayBeInitializeNodeDef() { } } -// static -Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = CreateOpKernel(device->device_type().c_str(), device, - device->GetAllocator(AllocatorAttributes()), - nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); - out->device_ = device; - out->kernel_.reset(k); - out->flib_ = nullptr; - return s; -} - -// static -Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = flib->CreateKernel(ndef, &k); - out->device_ = flib->device(); - out->kernel_.reset(k); - out->flib_ = flib; - return s; -} - -Status KernelAndDevice::Run(std::vector* input_tensors, - std::vector* output_tensors, - NodeExecStats* stats) { - gtl::InlinedVector inputs; - for (Tensor& t : *input_tensors) { - inputs.push_back(TensorValue(&t)); - } - - std::vector out_attrs(kernel_->num_outputs()); - for (size_t i = 0; i < out_attrs.size(); ++i) { - out_attrs[i].set_on_host(kernel_->output_memory_types()[i] == - tensorflow::HOST_MEMORY); - } - - OpKernelContext::Params params; - params.device = device_; - params.frame_iter = FrameAndIter(0, 0); - params.inputs = &inputs; - params.op_kernel = kernel_.get(); - params.resource_manager = device_->resource_manager(); - params.output_attr_array = gtl::vector_as_array(&out_attrs); - params.function_library = flib_; - params.slice_reader_cache = &slice_reader_cache_; - params.rendezvous = rendez_; - if (stats != nullptr) { - params.track_allocations = true; - } - // TODO(apassos): use a thread pool. - std::function)> runner = - [](std::function f) { f(); }; - params.runner = &runner; - - OpKernelContext context(¶ms); - device_->Compute(kernel_.get(), &context); - if (!context.status().ok()) return context.status(); - - output_tensors->clear(); - for (int i = 0; i < context.num_outputs(); ++i) { - output_tensors->push_back(Tensor(*context.mutable_output(i))); - } - if (stats != nullptr) { - for (const auto& allocator_pair : context.wrapped_allocators()) { - AllocatorMemoryUsed* memory = stats->add_memory(); - memory->set_allocator_name(allocator_pair.first->Name()); - auto sizes = allocator_pair.second->GetSizes(); - memory->set_total_bytes(std::get<0>(sizes)); - memory->set_peak_bytes(std::get<1>(sizes)); - memory->set_live_bytes(std::get<2>(sizes)); - - AllocatorStats allocator_stats; - allocator_pair.first->GetStats(&allocator_stats); - memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); - allocator_pair.second->GetRecordsAndUnRef(); - } - auto* ms = stats->mutable_memory_stats(); - ms->set_temp_memory_size(context.temp_memory_allocated()); - for (const auto& alloc_id : context.persistent_alloc_ids()) { - ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); - } - - ms->set_persistent_memory_size(context.persistent_memory_allocated()); - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index 4d20b5244a46fcde2eed0a429dced2a77b86aedd..929b1b8296faf61c11c68af06ffc4ca3770ae929 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -39,9 +40,16 @@ namespace tensorflow { // represent the TF_AttrType type of the values in the list. typedef std::unordered_map AttrTypeMap; +// Look up OpDef for `op_name`. +Status OpDefForOp(const char* op_name, const OpDef** op_def); + // Returns the AttrTypeMap for the TensorFlow operation named op_name. Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); +// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list); + // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); @@ -146,47 +154,6 @@ template <> AttrBuilder& AttrBuilder::Set(StringPiece attr_name, tensorflow::DataType&& value); -// KernelAndDevice encapsulates an instantiated kernel and the device it is on. -// -// Also see: -// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h -// and -// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h -class KernelAndDevice { - public: - // Populates 'out' with a kernel appropriate for 'ndef'. - // - // The provided FunctionLibraryRuntime MUST outlive all calls to - // Run() on the returned KernelAndDevice. - // - // TODO(ashankar): Figure out thread-safety concerns around - // FunctionLibraryRuntime (in particular, how the underlying - // FunctionLibraryDefinition might be mutated by another thread as new - // functions are registered with it). Conservatively, thread-safe usage of - // the FunctionLibraryRuntime is pushed on to the caller (see locking in - // c_api.cc). - static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, - KernelAndDevice* out); - // TODO(ashankar): Remove this - static Status InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out); - - KernelAndDevice(tensorflow::Rendezvous* rendez) - : device_(nullptr), flib_(nullptr), rendez_(rendez) {} - - // TODO(ashankar): Handle list-valued inputs. - Status Run(std::vector* inputs, std::vector* outputs, - NodeExecStats* stats); - - const OpKernel* kernel() const { return kernel_.get(); } - - private: - std::unique_ptr kernel_; - Device* device_; - FunctionLibraryRuntime* flib_; - checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; - Rendezvous* rendez_; -}; } // namespace tensorflow diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 643153058ce3d6f0c88dd23a0dec4c6eff060319..27ebeb0508844ee1ee89e0733b66f6ed129b7757 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -33,27 +33,6 @@ limitations under the License. namespace tensorflow { namespace { -class TestEnv { - public: - TestEnv() : flib_def_(OpRegistry::Global(), {}) { - Device* device = - DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); - device_mgr_.reset(new DeviceMgr({device})); - flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(), - device, TF_GRAPH_DEF_VERSION, - &flib_def_, {}, nullptr); - } - - FunctionLibraryRuntime* function_library_runtime() const { - return flib_runtime_.get(); - } - - private: - FunctionLibraryDefinition flib_def_; - std::unique_ptr device_mgr_; - std::unique_ptr flib_runtime_; -}; - TEST(AttrTypeMap, Lookup) { const AttrTypeMap* m = nullptr; Status s = AttrTypeMapForOp("ThisOpCannotPossiblyExist", &m); @@ -79,113 +58,5 @@ TEST(AttrTypeMap, Lookup) { EXPECT_NE(is_list, 0); } -TEST(KernelAndDevice, Run) { - Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); - std::vector inputs; - inputs.push_back(t); - inputs.push_back(t); - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(inputs.size()) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice kernel(nullptr); - Status s = - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); - ASSERT_TRUE(s.ok()) << s; - std::vector outputs; - s = kernel.Run(&inputs, &outputs, nullptr); - ASSERT_TRUE(s.ok()) << s; - ASSERT_EQ(1, outputs.size()); - const Tensor& out = outputs[0]; - EXPECT_EQ(7, out.matrix()(0, 0)); - EXPECT_EQ(10, out.matrix()(0, 1)); - EXPECT_EQ(15, out.matrix()(1, 0)); - EXPECT_EQ(22, out.matrix()(1, 1)); -} - -void BM_CreateGraph(int iters) { - for (int i = 0; i < iters; ++i) { - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - TF_CHECK_OK(root.status()); - } -} -BENCHMARK(BM_CreateGraph); - -void BM_RunGraph(int iters) { - tensorflow::testing::StopTiming(); - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - SessionOptions opts; - opts.config.set_inter_op_parallelism_threads(1); - opts.config.set_intra_op_parallelism_threads(1); - ClientSession sess(root, opts); - std::vector outputs; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - outputs.clear(); - TF_CHECK_OK(sess.Run({M}, &outputs)); - } -} -BENCHMARK(BM_RunGraph); - -void BM_CreateAndDestroySession(int iters) { - tensorflow::testing::StopTiming(); - Scope root = Scope::NewRootScope(); - auto C = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}}); - auto M = ops::MatMul(root, C, C); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - ClientSession sess(root); - } -} -BENCHMARK(BM_CreateAndDestroySession); - -void BM_KernelAndDeviceInit(int iters) { - tensorflow::testing::StopTiming(); - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(2) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice k(nullptr); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &k)); - } -} -BENCHMARK(BM_KernelAndDeviceInit); - -void BM_KernelAndDeviceRun(int iters) { - tensorflow::testing::StopTiming(); - Tensor t(Input({{1.0f, 2.0f}, {3.0f, 4.0f}}).tensor()); - std::vector inputs; - inputs.push_back(t); - inputs.push_back(t); - std::vector outputs; - NodeDef ndef(AttrBuilder("MatMul") - .Set("T", DT_FLOAT) - .Set("transpose_a", false) - .Set("transpose_b", false) - .NumInputs(inputs.size()) - .BuildNodeDef()); - TestEnv env; - KernelAndDevice kernel(nullptr); - TF_CHECK_OK( - KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); - } -} -BENCHMARK(BM_KernelAndDeviceRun); } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index bdb0815d6b68444ec1c89b835d563db20ce4d8a1..97c323b87228039ba10f4ed5e434aa83621b1220 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -152,6 +152,8 @@ class GradientTape { gtl::ArraySlice output_gradients, std::vector* result); + bool IsPersistent() const { return persistent_; } + private: TensorTape tensor_tape_; OpTape op_tape_; @@ -599,23 +601,28 @@ Status GradientTape::ComputeGradient( } CHECK(state.op_tape.empty()); result->reserve(source_tensor_ids.size()); + gtl::FlatSet used_gradient_ids(source_tensor_ids.size()); for (auto is : source_tensor_ids) { auto grad_it = gradients.find(is); if (grad_it == gradients.end()) { result->push_back(nullptr); } else { - if (grad_it->second.size() == 1) { - result->push_back(grad_it->second[0]); - } else { - result->push_back(vspace.AggregateGradients(grad_it->second)); + if (grad_it->second.size() > 1) { + Gradient* grad = vspace.AggregateGradients(grad_it->second); + grad_it->second.clear(); + grad_it->second.push_back(grad); } - gradients.erase(grad_it); + result->push_back(grad_it->second[0]); + used_gradient_ids.insert(is); } } - VLOG(1) << "Final gradients size: " << gradients.size(); + VLOG(1) << "Final gradients size: " + << gradients.size() - used_gradient_ids.size(); for (auto grad_pair : gradients) { - for (const auto& g : grad_pair.second) { - vspace.DeleteGradient(g); + if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } } } return Status::OK(); diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68..93155998b86d59ec78c7ff25f146b8e3c8eac380 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/python_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/python/framework/cpp_shape_inference.pb.h" namespace tensorflow { @@ -99,4 +100,39 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { } } +void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { + mutex_lock l(graph->mu); + graph->refiner.set_require_shape_inference_fns(require); +} + +void ExtendSession(TF_Session* session, TF_Status* status) { + ExtendSessionGraphHelper(session, status); + session->extend_before_run = false; +} + +std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { + Node* node = &output.oper->node; + CppShapeInferenceResult::HandleData handle_data; + handle_data.set_is_set(true); + { + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + CHECK(ic != nullptr); + CHECK_LT(output.index, ic->num_outputs()); + const auto* shapes_and_types = + ic->output_handle_shapes_and_types(output.index); + if (shapes_and_types == nullptr) return ""; + + for (const auto& p : *shapes_and_types) { + auto* out_shape_and_type = handle_data.add_shape_and_type(); + ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape()); + out_shape_and_type->set_dtype(p.dtype); + } + } + string result; + handle_data.SerializeToString(&result); + return result; +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index aa9d9e06b28c54cb8869eb547d36ee3cb0d4e6b8..2d4c8cd9ed7bc926f448dab1f6b50ed74179ea14 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_PYTHON_API_H_ #define TENSORFLOW_C_PYTHON_API_H_ +#include + #include "tensorflow/c/c_api.h" // These functions can be removed without notice. They exist to facilitate some @@ -37,6 +39,25 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); +// Sets whether ops missing a shape inference function should trigger an +// error. The default is true. +void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); + +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. +// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. +void ExtendSession(TF_Session* session, TF_Status* status); + +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource tensor, or otherwise returns the empty string. +// TODO(b/74620627): remove when _USE_C_SHAPES is removed +std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/testdata/tf_record b/tensorflow/c/testdata/tf_record new file mode 100644 index 0000000000000000000000000000000000000000..6e16076bfb79ad8151952e96567565e8820b0f5b Binary files /dev/null and b/tensorflow/c/testdata/tf_record differ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 9060c19e9d2cf965c2b9be07be07c42017da45a8..079e063d3e3fbdaf833e9031f5f9438853c14099 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -620,18 +620,6 @@ tf_cc_binary( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - cc_library( name = "queue_runner", srcs = ["training/queue_runner.cc"], diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a40ad1ffc3b262840e6ca0043139b1b61e04510d..d73121c7b701ec06c03836d1a765f4b35d88fe92 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" @@ -697,7 +698,8 @@ string OpInfo::GetOpAttrStruct() const { attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); - strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); + strings::StrAppend(&setters, " TF_MUST_USE_RESULT Attrs ", attr_func_def, + " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), "_ = x;\n"); diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 1e0f2d241bb350897a840dda90d6d0c009b1daad..5d9dfd95a5538ae0f3d2d111a1f989552c3363b8 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -61,12 +62,12 @@ op { )"; void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(s.contains(expected)) + EXPECT_TRUE(str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { - EXPECT_FALSE(s.contains(expected)) + EXPECT_FALSE(str_util::StrContains(s, expected)) << "'" << s << "' contains '" << expected << "'"; } diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 71642492627422e09c19b7bcb4dc522846cf08b1..c143b978338815ebc7134eb0a07867c5d8b13dca 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { @@ -218,7 +219,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { for (const string& entry : node_constraints) { StringPiece s(entry); - if (s.Consume(kColocationGroupPrefix)) { + if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { current_constraints.insert(s.ToString()); } } diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index 0734075fc6144d7c9f4fdb48c5e097faa58b8355..81870a0efa309ae6dbd5cc05a5dbe8c3e2d437c8 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -72,9 +72,9 @@ Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, }; // Body function that adds one to input. - BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope, - const std::vector& inputs, - std::vector* outputs) { + BodyGraphBuilderFn body_fn = [](const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { DCHECK_EQ(inputs.size(), 1); outputs->emplace_back(ops::Add(scope, inputs[0], 1)); return scope.status(); diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d..0cb3132e94e381f672d69aefe4a199d2b590830c 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -48,8 +48,8 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad); Status LogSoftmaxGrad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { + const std::vector& grad_inputs, + std::vector* grad_outputs) { auto softmax = Exp(scope, op.output(0)); auto sum = Sum(scope, grad_inputs[0], {1}, Sum::KeepDims(true)); auto mul = Mul(scope, sum, softmax); @@ -107,11 +107,10 @@ Status BiasAddGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { string data_format; - BiasAddGrad::Attrs input_attrs; TF_RETURN_IF_ERROR( GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format)); - input_attrs.DataFormat(data_format); - auto dx_1 = BiasAddGrad(scope, grad_inputs[0], input_attrs); + auto dx_1 = + BiasAddGrad(scope, grad_inputs[0], BiasAddGrad::DataFormat(data_format)); grad_outputs->push_back(Identity(scope, grad_inputs[0])); grad_outputs->push_back(dx_1); return scope.status(); @@ -130,19 +129,16 @@ Status Conv2DGrad(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "use_cudnn_on_gpu", &use_cudnn_on_gpu)); - Conv2DBackpropInput::Attrs input_attrs; - input_attrs.DataFormat(data_format); - input_attrs.UseCudnnOnGpu(use_cudnn_on_gpu); - auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), - op.input(1), grad_inputs[0], - strides, padding, input_attrs); + auto dx_1 = Conv2DBackpropInput(scope, Shape(scope, op.input(0)), op.input(1), + grad_inputs[0], strides, padding, + Conv2DBackpropInput::DataFormat(data_format) + .UseCudnnOnGpu(use_cudnn_on_gpu)); grad_outputs->push_back(dx_1); - Conv2DBackpropFilter::Attrs filter_attrs; - filter_attrs.DataFormat(data_format); - filter_attrs.UseCudnnOnGpu(use_cudnn_on_gpu); - auto dx_2 = Conv2DBackpropFilter(scope, op.input(0), - Shape(scope, op.input(1)), grad_inputs[0], - strides, padding, filter_attrs); + auto dx_2 = + Conv2DBackpropFilter(scope, op.input(0), Shape(scope, op.input(1)), + grad_inputs[0], strides, padding, + Conv2DBackpropFilter::DataFormat(data_format) + .UseCudnnOnGpu(use_cudnn_on_gpu)); grad_outputs->push_back(dx_2); return scope.status(); } @@ -160,13 +156,9 @@ Status MaxPoolGradHelper(const Scope& scope, const Operation& op, TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); - internal::MaxPoolGrad::Attrs grad_attrs; - grad_attrs.DataFormat(data_format); - auto dx = internal::MaxPoolGrad(scope, op.input(0), - op.output(0), - grad_inputs[0], - ksize, strides, - padding, grad_attrs); + auto dx = internal::MaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], ksize, strides, padding, + internal::MaxPoolGrad::DataFormat(data_format)); grad_outputs->push_back(dx); return scope.status(); } @@ -180,15 +172,9 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, auto attrs = op.output(0).node()->attrs(); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); - MaxPoolGradV2::Attrs grad_attrs; - grad_attrs.DataFormat(data_format); - auto dx = MaxPoolGradV2(scope, op.input(0), - op.output(0), - grad_inputs[0], - op.input(1), - op.input(2), - padding, - grad_attrs); + auto dx = MaxPoolGradV2(scope, op.input(0), op.output(0), grad_inputs[0], + op.input(1), op.input(2), padding, + MaxPoolGradV2::DataFormat(data_format)); grad_outputs->push_back(dx); grad_outputs->push_back(NoGradient()); grad_outputs->push_back(NoGradient()); @@ -196,13 +182,74 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + MaxPool3DGrad::Attrs grad_attrs; + auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); + +Status AvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + internal::AvgPoolGrad::Attrs grad_attrs; + auto dx = + internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); + +Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + AvgPool3DGrad::Attrs grad_attrs; + auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, + grad_attrs.DataFormat(data_format)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper); + Status LRNGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, - std::vector* grad_outputs){ - internal::LRNGrad::Attrs grad_attrs; - - auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), - grad_attrs); + std::vector* grad_outputs) { + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0)); grad_outputs->push_back(dx); return scope.status(); } diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 0cfe5f6e3c49f7c4a3cafbf48ff4e54a0ffd0d47..c4eba7ecb017fe4628140d75a63bc7f0f09deb7f 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -31,8 +31,11 @@ using ops::Elu; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; +using ops::AvgPool; +using ops::AvgPool3D; using ops::MaxPool; using ops::MaxPoolV2; +using ops::MaxPool3D; using ops::Placeholder; using ops::Relu; using ops::Relu6; @@ -70,9 +73,9 @@ class NNGradTest : public ::testing::Test { // Sets tensor with random values, ensuring that the max value is largest by // a reasonable amount. - // This is an issue for MaxPool and MaxPoolV2, in which perturbations by the - // numeric gradient computation in the gradient checker can change the max - // value if values are too close together. + // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which + // perturbations by the numeric gradient computation in the gradient checker + // can change the max value if values are too close together. template void SetRandomValuesWithBumpedMax(Tensor* tensor) { auto tensor_flat = tensor->flat(); @@ -203,6 +206,41 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, MaxPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one MaxPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesWithBumpedMax(&x_init_value); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NNGradTest, AvgPoolGradHelper) { + TensorShape x_shape({1, 2, 2, 1}); + TensorShape y_shape({1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool. + const std::vector ksize{1, 2, 2, 1}; + const std::vector strides{1, 2, 2, 1}; + auto y = AvgPool(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + +TEST_F(NNGradTest, AvgPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(NNGradTest, LRN){ TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index 00799526fce572e7bb80199ccb8ce1cc89874031..cf65fe1ab99b49207a64e86310178141b30d07d7 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -9,6 +9,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], + tags = [ + "noguitar", # b/77649654 + ], deps = [ ":profiler", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index 6077c45c5854fd5812ccb7c91522f93ed4e54883..64edbb5766c3604fbe0f15c2299843718381aa3f 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -61,18 +61,18 @@ class Profiler { /// Adds tracing information `run_meta` to profiler. A `run_meta` is /// generated by a TensorFlow session run call. `step` is the key /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to seletively profile the corresponding `run_meta`. + /// `step` in `options` to selectively profile the corresponding `run_meta`. /// Multiple different `run_meta` can be keyed by the same `step` in order /// to group them together. void AddStep(int64 step, const RunMetadata& run_meta); /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are contected by the op inputs/outputs. + /// Each node is an op and the nodes are connected by the op inputs/outputs. GraphNodeProto ProfileGraph(const Options& options); /// Profiles the model by organizing nodes in name scope structure. /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a filesystem tree. + /// scope, similar to a file system tree. /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. GraphNodeProto ProfileNameScope(const Options& options); diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index d29ad3ebcbe29087d5572b51c7713e0c98d0d840..06a3be18e08f611d3ecf9804908d791d15fdab13 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -94,18 +94,3 @@ filegroup( "testdata/half_plus_two/**", ]), ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 4c64d2cfe3c10e6c7ed82a2d72460a0b34283bb2..72b8bc18710b0ee77cb01ed3ad0c2abb5183efb2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -133,9 +134,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied " - "tags: { missing-tag }")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: { missing-tag }")) << st.error_message(); } @@ -149,9 +150,9 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe, "missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags: ")) + EXPECT_TRUE(str_util::StrContains( + st.error_message(), + "Could not find meta graph def matching supplied tags: ")) << st.error_message(); } @@ -169,7 +170,7 @@ TEST_F(LoaderTest, SessionCreationFailure) { Status st = LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagServe}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + EXPECT_TRUE(str_util::StrContains(st.error_message(), kInvalidTarget)) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD index f5fbc75edcba9d5ae9ef7432de224df766bcab9e..6f04ebdc55cda329527c95f62efc37c8dfbb4ae5 100644 --- a/tensorflow/cc/saved_model/python/BUILD +++ b/tensorflow/cc/saved_model/python/BUILD @@ -7,18 +7,6 @@ package( default_visibility = ["//visibility:public"], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc") tf_py_clif_cc( diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD index 97f66e79b8ad9f383b22f56e9385fc6d2080e1f8..6f1c87354076565af22f7ba0610a5c6bb999d25c 100644 --- a/tensorflow/cc/tools/BUILD +++ b/tensorflow/cc/tools/BUILD @@ -32,6 +32,7 @@ tf_cc_test( deps = [ ":freeze_saved_model", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework_internal", "//tensorflow/core:protos_all_cc", @@ -40,18 +41,3 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) - -# ----------------------------------------------------------------------------- -# Google-internal targets. - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index ddf372cdef21e1b3892c9a03714478d5a5785517..4ddddcb5863c9ffb1e5367db750b0d2ffd29cd5e 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -75,16 +75,13 @@ void GetNodeNameToNodeDefMap( // variable nodes to convert. void GetReachableNodesAndVariables( GraphDef* graph_def, const std::unordered_set& outputs, + const std::unordered_map& name_to_node_map, std::unordered_set* reachable_node_names, std::unordered_set* variable_node_names) { // TODO(suharshs): Add support for ResourceVariables. static const std::unordered_set* kVariableTypes = - new std::unordered_set({"Variable", "VariableV2"}); - // name_to_node_map is needed to get the inputs from the NodeDef corresponding - // the a string node name. These inputs are used when doing our backwards - // traversal. - std::unordered_map name_to_node_map; - GetNodeNameToNodeDefMap(graph_def, &name_to_node_map); + new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); + std::queue nodes_to_visit; for (const string& tensor_name : outputs) { // We need to strip off the tensor part to get the node name. @@ -99,7 +96,7 @@ void GetReachableNodesAndVariables( continue; } reachable_node_names->insert(node_name); - NodeDef* node = name_to_node_map[node_name]; + NodeDef* node = name_to_node_map.at(node_name); if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } @@ -111,7 +108,9 @@ void GetReachableNodesAndVariables( // Gets a map from variable name to variable value. Status GetVariableNameToTensorMap( - Session* session, std::unordered_set variable_names_set, + Session* session, + const std::unordered_map& name_to_node_map, + std::unordered_set variable_names_set, std::unordered_map* variable_name_to_value_map) { if (variable_names_set.empty()) { return Status::OK(); @@ -120,8 +119,14 @@ Status GetVariableNameToTensorMap( std::vector tensor_names; for (const string& node_name : variable_names_set) { variable_names.push_back(node_name); - // We need to run tensors, so append ":0". - tensor_names.push_back(node_name + ":0"); + NodeDef* node_def = name_to_node_map.at(node_name); + if (node_def->op() == "VarHandleOp") { + // If this is a resource variable, we have to run the corresponding + // ReadVariableOp. + tensor_names.push_back(node_name + "/Read/ReadVariableOp:0"); + } else { + tensor_names.push_back(node_name + ":0"); + } } std::vector outputs; TF_RETURN_IF_ERROR( @@ -143,6 +148,15 @@ void ConvertVariableToConstant(const NodeDef& variable_node, (*const_node->mutable_attr())["value"].mutable_tensor()); } +// Converts a ReadVariableOp NodeDef to an Identity NodeDef. +void ConvertReadVariableOpToIdentity(const NodeDef& node, + NodeDef* identity_node) { + identity_node->set_name(node.name()); + identity_node->set_op("Identity"); + (*identity_node->mutable_attr())["T"] = node.attr().at("dtype"); + identity_node->add_input(node.input(0)); +} + // Freezes the subgraph of all nodes needed by `outputs`. Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, const std::unordered_set& outputs, @@ -155,14 +169,19 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, if (graph_def.node_size() == 0) { return Status::OK(); } + // name_to_node_map is needed to get the inputs from the NodeDef corresponding + // the a string node name. These inputs are used when doing our backwards + // traversal. + std::unordered_map name_to_node_map; + GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map); std::unordered_set reachable_node_names; std::unordered_set variable_node_names; - GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names, - &variable_node_names); + GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map, + &reachable_node_names, &variable_node_names); std::unordered_map variable_to_value_map; - TF_RETURN_IF_ERROR( - GetVariableNameToTensorMap(saved_model_bundle.session.get(), - variable_node_names, &variable_to_value_map)); + TF_RETURN_IF_ERROR(GetVariableNameToTensorMap( + saved_model_bundle.session.get(), name_to_node_map, variable_node_names, + &variable_to_value_map)); // We copy the nodes in the same order they were in the original graph_def. for (const NodeDef& node : graph_def.node()) { if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { @@ -171,6 +190,12 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, if (variable_node_names.find(node.name()) != variable_node_names.end()) { ConvertVariableToConstant(node, variable_to_value_map[node.name()], frozen_graph_def->add_node()); + } else if (node.op() == "ReadVariableOp" && + variable_node_names.find(node.input(0)) != + variable_node_names.end()) { + // If the node is a ReadVariableOp, its input VarHandleOp will be + // converted to a Constant, so we will need to convert it to an Identity. + ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node()); } else { // If the node isn't a variable, just copy the node as-is. *frozen_graph_def->add_node() = node; diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 52a81a50284aec36bba4e56a0232c886cb0cb6cf..cd35fd3b95deec669218cfa4f25fea2c3ac9e56e 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" @@ -113,6 +114,160 @@ class FreezeTest : public ::testing::Test { test::ExpectTensorEqual(unfrozen_outputs[0], frozen_outputs[0]); } + + void TestFreezeGraphWithoutDependentVariables(bool use_resource) { + // Test freezing a graph with variables that are not needed by the outputs + // in the SignatureDef. The resulting graph shouldn't be frozen, but + // non-dependent nodes should be pruned. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + Output read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + } else { + Output var = + ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + } + + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + GraphDef expected_graph_def; + Scope expected_scope = Scope::NewRootScope(); + Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); + Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); + Output expected_c = + ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); + TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); + + GraphDefEqual(frozen_graph_def, expected_graph_def); + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } + + void TestFreezeGraphWithDependentVariables(bool use_resource) { + // Test freezing a graph with variables that are needed by outputs in the + // SignatureDef. The variables should be frozen. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output read_var; + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + } else { + Output read_var = + ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a); + } + Output c = ops::Mul(scope.WithOpName("c"), a, read_var); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + // If using normal variables there should be 3 nodes in the resulting + // graph_def. If using resource variables there should be 4 nodes in the + // resulting graph_def. + // In both cases, none should be variables. + size_t expected_nodes = use_resource ? 4 : 3; + EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + EXPECT_NE(node.op(), "VarHandleOp") << node.name(); + EXPECT_NE(node.op(), "ReadVariableOp") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } + + void TestFreezeGraphWithAndWithoutDependentVariables(bool use_resource) { + // Test freezing a graph with some variables that are needed and not needed + // by + // the outputs in the SignatureDef. The resulting graph should only freeze + // dependent variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output read_var; + + if (use_resource) { + Output var = + ops::VarHandleOp(scope.WithOpName("var"), DataType::DT_FLOAT, {}); + read_var = ops::ReadVariableOp( + scope.WithOpName("var/Read/ReadVariableOp"), var, DataType::DT_FLOAT); + auto assign = ops::AssignVariableOp(scope.WithOpName("assign"), var, a); + Output var_1 = + ops::VarHandleOp(scope.WithOpName("var_1"), DataType::DT_FLOAT, {}); + Output read_var_1 = + ops::ReadVariableOp(scope.WithOpName("var_1/Read/ReadVariableOp"), + var, DataType::DT_FLOAT); + auto assign_1 = + ops::AssignVariableOp(scope.WithOpName("assign_1"), var_1, a); + } else { + read_var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), read_var, a); + Output var_1 = + ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); + Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var_1, a); + } + + Output c = ops::Mul(scope.WithOpName("c"), a, read_var); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, "assign", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, + &inputs, &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + size_t expected_nodes = use_resource ? 4 : 3; + EXPECT_EQ(frozen_graph_def.node_size(), expected_nodes); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + EXPECT_NE(node.op(), "VarHandleOp") << node.name(); + EXPECT_NE(node.op(), "ReadVariableOp") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); + } }; TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) { @@ -196,111 +351,28 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } -TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { - // Test freezing a graph with variables that are not needed by the outputs in - // the SignatureDef. The resulting graph shouldn't be frozen, but - // non-dependent nodes should be pruned. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); - Output c = ops::Mul(scope.WithOpName("c"), a, b); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); - - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); - - GraphDef expected_graph_def; - Scope expected_scope = Scope::NewRootScope(); - Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); - Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); - Output expected_c = - ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); - TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); - - GraphDefEqual(frozen_graph_def, expected_graph_def); - - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { + TestFreezeGraphWithoutDependentVariables(false); } -TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { - // Test freezing a graph with variables that are needed by outputs in the - // SignatureDef. The variables should be frozen. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output c = ops::Mul(scope.WithOpName("c"), a, var); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); - - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); - - // There should be 3 nodes in the resulting graph_def, and none should be - // variables. - EXPECT_EQ(frozen_graph_def.node_size(), 3); - for (const NodeDef& node : frozen_graph_def.node()) { - EXPECT_NE(node.op(), "Variable") << node.name(); - EXPECT_NE(node.op(), "VariableV2") << node.name(); - } - - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithoutDependentResourceVariables) { + TestFreezeGraphWithoutDependentVariables(true); } -TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) { - // Test freezing a graph with some variables that are needed and not needed by - // the outputs in the SignatureDef. The resulting graph should only freeze - // dependent variables. - SavedModelBundle saved_model_bundle; - GraphDef graph_def; - Scope scope = Scope::NewRootScope(); - Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); - Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); - Output c = ops::Mul(scope.WithOpName("c"), a, var); - Output assign = ops::Assign(scope.WithOpName("assign"), var, a); - Output var_1 = - ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); - Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a); - TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. - TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( - graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); +TEST_F(FreezeTest, GraphDefWithDependentVariables) { + TestFreezeGraphWithDependentVariables(false); +} - GraphDef frozen_graph_def; - std::unordered_set inputs; - std::unordered_set outputs; - TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, - &outputs)); +TEST_F(FreezeTest, GraphDefWithDependentResourceVariables) { + TestFreezeGraphWithDependentVariables(true); +} - // There should be 3 nodes in the resulting graph_def, and none should be - // variables. - EXPECT_EQ(frozen_graph_def.node_size(), 3); - for (const NodeDef& node : frozen_graph_def.node()) { - EXPECT_NE(node.op(), "Variable") << node.name(); - EXPECT_NE(node.op(), "VariableV2") << node.name(); - } +TEST_F(FreezeTest, GraphDefWithAndWithoutDependentVariables) { + TestFreezeGraphWithAndWithoutDependentVariables(false); +} - RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), - frozen_graph_def, "c:0"); +TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) { + TestFreezeGraphWithAndWithoutDependentVariables(true); } } // namespace diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc index 3675d72ee354533a7d84b5e8783cde452d8d60c9..5dbc4f5f6aa389978e55ca2656c17ff97202203d 100644 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ b/tensorflow/cc/tutorials/example_trainer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/graph/default_device.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -166,7 +167,8 @@ namespace { bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, int32* dst) { - if (arg.Consume(flag) && arg.Consume("=")) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag) && + tensorflow::str_util::ConsumePrefix(&arg, "=")) { char extra; return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); } @@ -176,7 +178,7 @@ bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, bool* dst) { - if (arg.Consume(flag)) { + if (tensorflow::str_util::ConsumePrefix(&arg, flag)) { if (arg.empty()) { *dst = true; return true; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 0900e87ebabd378e6237b77ca0ef01677c07c244..19e6bf68e77725bb3cae4e1d338c52dff472cb18 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -60,6 +60,7 @@ cc_library( "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -72,6 +73,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], ) @@ -249,17 +251,3 @@ exports_files([ "benchmark_main.template", # used by tf_library(...,gen_benchmark=True) "test.cc", # used by tf_library(...,gen_test=True) ]) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 972b7d51ecb3798e61757ac55e973075a23b433a..2642536c4f67eba8eedf315f24d800e7913d62a0 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -33,7 +34,7 @@ namespace { void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd..7c833878818022c86fd3171ec9cef9fcd3217a24 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 6489929a576d6469c4ff1358ca5ee9d27fb578bb..0048eec93bbe10271d9aa535203f19473a38b342 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" -#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc index 5772776666129ed55a479c8917e69df3f3ce2fc0..5e74079fc158379b8977ada6412141e39142c3d3 100644 --- a/tensorflow/compiler/aot/runtime.cc +++ b/tensorflow/compiler/aot/runtime.cc @@ -31,7 +31,7 @@ namespace { inline void* aligned_malloc(size_t size, int minimum_alignment) { #if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN) return memalign(minimum_alignment, size); -#elif defined(COMPILER_MSVC) +#elif defined(_WIN32) return _aligned_malloc(size, minimum_alignment); #else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN void* ptr = nullptr; @@ -48,7 +48,7 @@ inline void* aligned_malloc(size_t size, int minimum_alignment) { } inline void aligned_free(void* aligned_memory) { -#if defined(COMPILER_MSVC) +#if defined(_WIN32) _aligned_free(aligned_memory); #else free(aligned_memory); diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 28aab6eb614ca7123d9e00f7f5cc3661b62e23f7..bb73cb19c57a654058af5bbb4535c76b0aca8e8c 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -14,6 +14,7 @@ test_suite( ":test_graph_tfadd_test", ":test_graph_tfadd_with_ckpt_saver_test", ":test_graph_tfadd_with_ckpt_test", + ":test_graph_tfassert_eq_test", ":test_graph_tffunction_test", ":test_graph_tfgather_test", ":test_graph_tfmatmul_test", @@ -33,6 +34,7 @@ py_binary( "//tensorflow/python", # TODO(b/34059704): remove when fixed "//tensorflow/python:array_ops", "//tensorflow/python:client", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform", @@ -52,6 +54,7 @@ genrule( "test_graph_tfadd_with_ckpt_saver.ckpt", "test_graph_tfadd_with_ckpt_saver.pb", "test_graph_tfadd_with_ckpt_saver.saver", + "test_graph_tfassert_eq.pb", "test_graph_tffunction.pb", "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", @@ -104,6 +107,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfassert_eq", + testonly = 1, + config = "test_graph_tfassert_eq.config.pbtxt", + cpp_class = "AssertComp", + graph = "test_graph_tfassert_eq.pb", + tags = [ + "manual", + ], +) + tf_library( name = "test_graph_tffunction", testonly = 1, @@ -170,6 +184,7 @@ tf_cc_test( ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", + ":test_graph_tfassert_eq", ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", @@ -182,17 +197,3 @@ tf_cc_test( "//third_party/eigen3", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 89c7cd4507cbd476104a039d6083d8f89de11278..67767f55dae9b15aafbd8b129328bde2c59a9ef3 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app @@ -125,6 +126,14 @@ def tfsplits(_): array_ops.identity(y, name='result') +def tfassert_eq(_): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + control_flow_ops.Assert( + math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') + math_ops.add(x, math_ops.negative(y), name='x_y_diff') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -144,6 +153,7 @@ def main(_): write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tffunction, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) + write_graph(tfassert_eq, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..8732d1709e809bb47d3769c483483c2c4f350e1c --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt @@ -0,0 +1,16 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "x_y_diff" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 413efd9cea3b6f71574615ad9ca92471ff925781..67dbd643bfc7bf2c214e7eb5ae8bd2cc7d6e164b 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h" #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" @@ -413,6 +414,23 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, AssertEqAndReturnDiff) { + // Assert is converted into a no-op in XLA, so there is no failure even if the + // two args are different. + AssertComp assert; + EXPECT_EQ(assert.arg0_data(), assert.args()[0]); + EXPECT_EQ(assert.arg1_data(), assert.args()[1]); + + assert.arg0() = 2; + assert.arg1() = 1; + const int32 expected_result = assert.arg0() - assert.arg1(); + EXPECT_TRUE(assert.Run()); + EXPECT_EQ(assert.error_msg(), ""); + EXPECT_EQ(assert.result0(), expected_result); + EXPECT_EQ(assert.result0_data()[0], expected_result); + EXPECT_EQ(assert.result0_data(), assert.results()[0]); +} + TEST(TFCompileTest, LookupNameIndex) { // add doesn't have any names defined in its config. AddComp add; diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 9dff1be09fede6f65f82c2f36d94be07e781949f..3a877c5337ff76193a7f27fb9681e5a9ca500961 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -132,7 +132,7 @@ def tf_library(name, graph, config, header_file = name + ".h" metadata_object_file = name + "_tfcompile_metadata.o" function_object_file = name + "_tfcompile_function.o" - ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + ep = ("__" + native.package_name() + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags else: diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index e2f01179d4e2e4f6ef72b2761d06e130ffa3a94f..8ea014c2eede2cb7a9cede9dd4ade8b970bd519c 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -55,7 +55,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (StringPiece(fname).ends_with(".pbtxt")) { + if (str_util::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a711319607f4ff2b83aa0ebe50e215b3d0e2258e..50fa95c4f322e85c22f7be2d63f2bcd194ee419e 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -29,7 +29,10 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( name = "jit", - visibility = [":friends"], + visibility = [ + ":friends", + "//learning/tfx:__subpackages__", + ], deps = [ ":xla_cpu_device", ":xla_cpu_jit", @@ -73,6 +76,7 @@ cc_library( ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep @@ -102,22 +106,46 @@ cc_library( cc_library( name = "xla_interpreter_device", srcs = ["xla_interpreter_device.cc"], + visibility = [":friends"], deps = [ + ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_tensor", + srcs = ["xla_tensor.cc"], + hdrs = ["xla_tensor.h"], + deps = [ + ":common", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], - alwayslink = True, ) cc_library( name = "xla_device", srcs = [ + "xla_compile_on_demand_op.cc", "xla_device.cc", "xla_device_context.cc", "xla_device_ops.cc", ], hdrs = [ + "xla_compile_on_demand_op.h", "xla_device.h", "xla_device_context.h", "xla_device_ops.h", @@ -127,6 +155,8 @@ cc_library( deps = [ ":common", ":jit_compilation_passes", + ":xla_launch_util", + ":xla_tensor", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:dump_graph", @@ -153,6 +183,13 @@ cc_library( ], ) +cc_library( + name = "shape_inference_helpers", + srcs = ["shape_inference_helpers.cc"], + hdrs = ["shape_inference_helpers.h"], + deps = ["//tensorflow/core:graph"], +) + # Internal targets below this point. cc_library( @@ -166,6 +203,29 @@ cc_library( visibility = [":friends"], ) +cc_library( + name = "xla_launch_util", + srcs = ["xla_launch_util.cc"], + hdrs = ["xla_launch_util.h"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:variable_ops", + ], +) + cc_library( name = "xla_compilation_cache", srcs = ["xla_compilation_cache.cc"], @@ -200,6 +260,7 @@ cc_library( name = "graph_to_functiondef", srcs = ["graph_to_functiondef.cc"], hdrs = ["graph_to_functiondef.h"], + visibility = [":friends"], deps = [ "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -239,6 +300,7 @@ cc_library( deps = [ ":common", ":graph_to_functiondef", + ":shape_inference_helpers", ":union_find", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", @@ -256,6 +318,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", ], ) @@ -264,6 +327,25 @@ cc_library( hdrs = ["union_find.h"], ) +cc_library( + name = "producer_consumer_queue", + hdrs = ["producer_consumer_queue.h"], + deps = ["//tensorflow/core:lib"], +) + +tf_cc_test( + name = "producer_consumer_queue_test", + size = "small", + srcs = ["producer_consumer_queue_test.cc"], + deps = [ + ":producer_consumer_queue", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "graph_to_functiondef_test", size = "small", @@ -296,6 +378,7 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", + ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -306,26 +389,13 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9c372a012789fc25ca0a711349c09ca62edc6754..9465385b5856baf4d03f280ff30572e196a7663b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" @@ -53,6 +55,8 @@ namespace tensorflow { const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; namespace { @@ -143,7 +147,7 @@ struct NodeSlot { // everything to use it. static const char* const kArgOp = "_Arg"; static const char* const kRetValOp = "_Retval"; -static const char* const kHostComputeOp = "_XlaHostCompute"; +static const char* const kHostComputeOp = "XlaHostCompute"; static const char* const kSendFromHostOp = "_XlaSendFromHost"; static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; @@ -252,7 +256,8 @@ class Encapsulator { // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( - const string& subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const std::unordered_map& node_images, Graph* graph_out); @@ -328,12 +333,14 @@ class Encapsulator { Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to - // all the downstream nodes of call_node_outputs. - void ConnectSequencerToOutputs(Graph* graph_out); + // the call node. + void ConnectSequencerToCallNode(Graph* graph_out); Status AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph); + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library); Status ReplaceFunctionDef(FunctionLibraryDefinition* library); @@ -381,15 +388,29 @@ class Encapsulator { Node* send_from_host = nullptr; }; + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Returns the (possible newly created) subgraph for + // outside_compilation_id. + OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( + const string& outside_compilation_id); + // Builds a ParallelCheck op that compares the output of the original // subgraph with the encapsulated subgraph. Status BuildParallelCheckOp( const std::unordered_map& node_images, Graph* graph_out); + // Builds a placeholder node used to provide the key input to a RecvAtHost + // or SendFromHost node. This placeholder node will be removed by a later + // pass. + Status AddHostComputeKeyPlaceholder(OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); + // Builds a _RecvAtHost node producing all the inputs of an // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. - Status AddRecvAtHostNode(const string& subgraph_name, + Status AddRecvAtHostNode(const string& group_attribute, + const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); @@ -398,8 +419,10 @@ class Encapsulator { // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. Status AddSendFromHostNode( const std::unordered_map& node_images, - const string& subgraph_name, const string& oc_subgraph_name, - OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, + const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are @@ -413,6 +436,14 @@ class Encapsulator { // NodeDef for the function call node. NodeDef call_node_def_; + // Name that is used for the call node. This may not be + // call_node_def_.name() if the client supplies a rewrite lambda. + string function_def_name_; + + // Placeholder node simulating the host compute key in the output graph. + // Not owned. + Node* host_compute_key_placeholder_ = nullptr; + // Function call node(s) in the output graph. Not owned. // If parallel_checking is enabled, 'call_node_inputs' is the function call // node to which inputs should be fed, and 'call_node_outputs' is the @@ -547,11 +578,12 @@ class Encapsulator { // satisfied, e.g., because send_node depends on a node that doesn't have a // registered shape inference function. Status DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const ShapeRefiner& shape_refiner, + const Graph& graph_in, const BackEdgeHelper& back_edge_helper, + const ShapeRefiner& shape_refiner, const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out); + std::unique_ptr* graph_out); // Makes a copy of graph containing only nodes that are ancestors of at least // one node in send_from_host_nodes and store it in pruned_graph. On exit @@ -570,7 +602,7 @@ class Encapsulator { // to nodes in pruned_graph. Status MakeGraphForOutsideCompilationSends( const Graph& graph, std::unique_ptr* pruned_graph, - ShapeRefiner* shape_refiner, + BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, std::unordered_map* node_images, FunctionLibraryDefinition* library); @@ -712,39 +744,44 @@ Status Encapsulator::Subgraph::RecordResult( return Status::OK(); } -void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( - const string& outside_compilation_id, const Edge* edge) { +Encapsulator::Subgraph::OutsideCompilationSubgraph* +Encapsulator::Subgraph::LookupOrCreateOutsideCompilationSubgraph( + const string& outside_compilation_id) { auto iter = outside_compilation_subgraphs_ .emplace(outside_compilation_id, OutsideCompilationSubgraph()) .first; - OutsideCompilationSubgraph& outside_subgraph = iter->second; + OutsideCompilationSubgraph* outside_subgraph = &iter->second; + return outside_subgraph; +} + +void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge) { + OutsideCompilationSubgraph* outside_subgraph = + LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); if (edge->IsControlEdge()) { - outside_subgraph.control_inputs.insert(edge->src()); + outside_subgraph->control_inputs.insert(edge->src()); } else { - int input_index = outside_subgraph.inputs.size(); - outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + int input_index = outside_subgraph->inputs.size(); + outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), + input_index); } } void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge) { - auto subgraph_iter = - outside_compilation_subgraphs_ - .emplace(outside_compilation_id, OutsideCompilationSubgraph()) - .first; - OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; + OutsideCompilationSubgraph* outside_subgraph = + LookupOrCreateOutsideCompilationSubgraph(outside_compilation_id); if (edge->IsControlEdge()) { - outside_subgraph.control_outputs.insert(edge->dst()); + outside_subgraph->control_outputs.insert(edge->dst()); } else { DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = - outside_subgraph.outputs_by_src + outside_subgraph->outputs_by_src .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph.outputs_by_src.size()) + outside_subgraph->outputs_by_src.size()) .first; int output_index = output_iter->second; - outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = output_index; } } @@ -791,6 +828,7 @@ Status Encapsulator::Subgraph::AddHostComputes( builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -842,25 +880,21 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, NodeDef seq_def; NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), "NoOp"); + builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); + builder.Device(device_); Status s = builder.Finalize(&seq_def); if (!s.ok()) return s; sequencer_ = graph_out->AddNode(seq_def, &s); if (!s.ok()) return s; - sequencer_->set_assigned_device_name(device_); } return Status::OK(); } -void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { +void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { - std::unordered_set output_dependencies; - for (Node* node : call_node_outputs_->out_nodes()) { - output_dependencies.insert(node); - } - for (Node* node : output_dependencies) { - graph_out->AddControlEdge(sequencer_, node); - } + VLOG(2) << "ConnectSequencerToCallNode"; + graph_out->AddControlEdge(sequencer_, call_node_inputs_); } } @@ -906,6 +940,8 @@ Status Encapsulator::Subgraph::BuildFunctionDef( name = call_node_def_.op(); } + function_def_name_ = name; + FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -924,8 +960,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( } Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& subgraph_name, const string& outside_compilation_subgraph_name, - const std::vector& shapes, GraphDef* inference_graph) { + const std::vector& shapes, Graph* inference_graph, + FunctionLibraryDefinition* library) { OutsideCompilationSubgraph& oc_subgraph = outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); @@ -947,21 +985,22 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shape_inference_graph", ""); host_compute->AddAttr("shapes", shapes); } else { - string serialized_graph; - if (!inference_graph->SerializeToString(&serialized_graph)) { - return errors::Internal( - "Failed to serialize graph for outside compilation subgraph ", - oc_subgraph.host_compute_name); - } - host_compute->AddAttr("shape_inference_graph", serialized_graph); + string inference_graph_name = + strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); + FunctionDef fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); + host_compute->AddAttr("shape_inference_graph", inference_graph_name); host_compute->AddAttr("shapes", std::vector()); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); } return Status::OK(); } Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { - const string& name = call_node_def_.name(); + const string& name = function_def_name_; FunctionDef fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); @@ -1060,9 +1099,37 @@ Status Encapsulator::Subgraph::AddFunctionCallNode( return Status::OK(); } +Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + TensorShapeProto shape_proto; + TensorShape shape({2}); + shape.AsProto(&shape_proto); + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeDef key_def; + NodeDefBuilder builder( + strings::StrCat(call_node_def_.name(), "_key_placeholder"), + "Placeholder"); + builder.Attr("dtype", DT_STRING); + builder.Attr("shape", shape_proto); + builder.Attr("_host_compute_call_node", call_node_def_.name()); + Status s = builder.Finalize(&key_def); + if (!s.ok()) return s; + + host_compute_key_placeholder_ = graph_out->AddNode(key_def, &s); + if (!s.ok()) return s; + host_compute_key_placeholder_->set_assigned_device_name(device_); + + return Status::OK(); +} + Status Encapsulator::Subgraph::AddRecvAtHostNode( - const string& subgraph_name, const string& oc_subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + if (host_compute_key_placeholder_ == nullptr) { + TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); + } + std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); for (const auto& input : oc_subgraph->inputs) { @@ -1078,15 +1145,23 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); + builder.Device(device_); builder.Attr("Toutputs", dtypes); + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. + builder.Attr("device_ordinal", 0); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + builder.Attr(group_attribute, subgraph_name); + builder.Attr(outside_compilation_attribute, oc_subgraph_name); + builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&recv_def); if (!s.ok()) return s; oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); if (!s.ok()) return s; - oc_subgraph->recv_at_host->set_assigned_device_name(device_); + graph_out->AddEdge(host_compute_key_placeholder_, 0, + oc_subgraph->recv_at_host, 0); // Add a control dependency forcing the RecvAtHost to run before the subgraph // completes. This has no effect on execution order but prevents the @@ -1099,8 +1174,13 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( Status Encapsulator::Subgraph::AddSendFromHostNode( const std::unordered_map& node_images, - const string& subgraph_name, const string& oc_subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const string& oc_subgraph_name, OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + if (host_compute_key_placeholder_ == nullptr) { + TF_RETURN_IF_ERROR(AddHostComputeKeyPlaceholder(oc_subgraph, graph_out)); + } + std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); std::vector inputs( oc_subgraph->outputs_by_src.size()); @@ -1120,16 +1200,24 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, "_", oc_subgraph_name, "_send"), kSendFromHostOp); + builder.Device(device_); builder.Attr("Tinputs", dtypes); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); + // The correct device_ordinal will be inserted during replication in a + // subsequent rewrite. + builder.Attr("device_ordinal", 0); + builder.Attr(group_attribute, subgraph_name); + builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(inputs); + builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); Status s = builder.Finalize(&send_def); if (!s.ok()) return s; oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); if (!s.ok()) return s; - oc_subgraph->send_from_host->set_assigned_device_name(device_); + graph_out->AddEdge(host_compute_key_placeholder_, 0, + oc_subgraph->send_from_host, inputs.size()); // Add a control dependency forcing the SendFromHost to run before the // subgraph completes. This has no effect on execution order but prevents the @@ -1141,7 +1229,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( - const string& subgraph_name, + const string& group_attribute, const string& subgraph_name, + const string& outside_compilation_attribute, const std::unordered_map& node_images, Graph* graph_out) { for (auto& outside_compilation_subgraph_entry : @@ -1151,14 +1240,16 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( outside_compilation_subgraph_entry.second; if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { - TF_RETURN_IF_ERROR( - AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); + TF_RETURN_IF_ERROR(AddRecvAtHostNode(group_attribute, subgraph_name, + outside_compilation_attribute, + oc_name, &oc_subgraph, graph_out)); } if (!oc_subgraph.outputs_by_src.empty() || !oc_subgraph.control_outputs.empty()) { - TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, - oc_name, &oc_subgraph, graph_out)); + TF_RETURN_IF_ERROR(AddSendFromHostNode( + node_images, group_attribute, subgraph_name, + outside_compilation_attribute, oc_name, &oc_subgraph, graph_out)); } } return Status::OK(); @@ -1375,8 +1466,6 @@ Status Encapsulator::CopyNodesToOutputGraph( "Parallel checking is not supported when outside_compilation " "clusters are present."); } - image->ClearAttr(group_attribute_); - image->ClearAttr(outside_compilation_attribute_); } (*node_images)[node] = image; } @@ -1402,7 +1491,8 @@ Status Encapsulator::AddOutsideCompilationHostIONodes( const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( - subgraph_name, node_images, graph_out)); + group_attribute_, subgraph_name, outside_compilation_attribute_, + node_images, graph_out)); } return Status::OK(); } @@ -1611,7 +1701,7 @@ Status Encapsulator::AddEdgesToOutputGraph( for (auto& subgraph_entry : subgraphs_) { Subgraph& subgraph = subgraph_entry.second; - subgraph.ConnectSequencerToOutputs(graph_out); + subgraph.ConnectSequencerToCallNode(graph_out); } return Status::OK(); @@ -1625,9 +1715,13 @@ namespace { // matter because it will only be used subsequently for shape inference. (It // would be possible to add a switch statement over data_type to create a value // for the constant, but that would entail maintaining the logic as new types -// are added, and is not necessary.) -Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, - Graph* graph_out) { +// are added, and is not necessary.) If the node being replaced was within a +// control flow frame, adds appropriate Enter nodes so that the use of the Const +// is well-formed. +Node* AddDummyShapedNode(const Node* src_node, int src_port, + const std::vector& control_flow_info, + const TensorShapeProto& shape, Graph* graph_out) { + DataType data_type = src_node->output_type(src_port); TensorProto dummy_proto; dummy_proto.set_dtype(data_type); *dummy_proto.mutable_tensor_shape() = shape; @@ -1638,7 +1732,23 @@ Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", options.op_registry()); node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); - return options.FinalizeBuilder(&node_builder); + Node* node = options.FinalizeBuilder(&node_builder); + // Add any Enter nodes required to bring the constant to the correct control + // flow frame. + while (!control_flow_info[src_node->id()].frame_name.empty()) { + NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter", + options.op_registry()); + enter_builder.Attr("frame_name", + control_flow_info[src_node->id()].frame_name); + enter_builder.Attr("is_constant", true); + enter_builder.Input(node, 0); + Node* enter_node = options.FinalizeBuilder(&enter_builder); + // Adopt the new Enter node as the value in the current frame. + node = enter_node; + // Recurse to the parent frame to see if more Enter nodes need to be added. + src_node = control_flow_info[src_node->id()].parent_frame; + } + return node; } // Adds a copy of node_in to graph_out and adds the mapping to @@ -1680,17 +1790,30 @@ Status CopyShapeInferenceNodeToGraph( } } } + // Work around the fact that Enter nodes refuse to propagate shape information + // unless they are marked loop invariant. Since we are never going to execute + // this graph, marking them all loop invariant is fine. + if (node_out->type_string() == "Enter") { + node_out->ClearAttr("is_constant"); + node_out->AddAttr("is_constant", true); + } return Status::OK(); } } // namespace Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( - const Graph& graph_in, const ShapeRefiner& shape_refiner, + const Graph& graph_in, const BackEdgeHelper& back_edge_helper, + const ShapeRefiner& shape_refiner, const std::unordered_set& recv_at_host_nodes, Node* send_node, FunctionLibraryDefinition* library, std::vector* static_shape_out, - std::unique_ptr* graphdef_out) { + std::unique_ptr* graph_out) { + // Get the control flow structure of the input graph so we can build + // well-formed output graphs. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph_in, &control_flow_info)); + // Maps from nodes in graph_in to nodes in graph_out. // // When an edge has fully defined shape the source node in graph_in is @@ -1707,13 +1830,14 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( std::unordered_map dummy_node_images; std::unordered_map copied_node_images; - std::unique_ptr graph_out(new Graph(graph_in.op_registry())); - graph_out->set_versions(graph_in.versions()); - static_shape_out->resize(send_node->num_inputs()); + graph_out->reset(new Graph(graph_in.op_registry())); + (*graph_out)->set_versions(graph_in.versions()); + // The final input to the send node is the dynamic key, which we don't include + // in the static shapes. + static_shape_out->resize(send_node->num_inputs() - 1); // We don't use the standard ReverseDFS because we want to cut off traversal // whenever we find an output with fully defined shape. - // TODO(misard) make this work properly in the presence of control flow. struct Work { Node* node; bool leave; // Are we entering or leaving node? @@ -1728,7 +1852,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( if (w.leave) { TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( n, send_node, dummy_node_images, library, &copied_node_images, - graph_out.get())); + graph_out->get())); } else { if (visited[n->id()]) continue; visited[n->id()] = true; @@ -1750,14 +1874,24 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // continue. TensorShapeProto proto; context->ShapeHandleToProto(shape, &proto); - dummy_node_images[src_node] = AddDummyShapedNode( - src_node->output_type(src_port), proto, graph_out.get()); - if (n == send_node) { + if (dummy_node_images.find(src_node) == dummy_node_images.end()) { + dummy_node_images[src_node] = + AddDummyShapedNode(src_node, src_port, control_flow_info, + proto, graph_out->get()); + } + // The final input to the send node is the dynamic key, which we + // don't include in the static shapes. + if (n == send_node && + in_edge->dst_input() < static_shape_out->size()) { (*static_shape_out)[in_edge->dst_input()] = proto; } } else { + has_parent_with_unknown_shape = true; if (!visited[src_node->id()]) { - has_parent_with_unknown_shape = true; + if (VLOG_IS_ON(2)) { + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + } stack.push_back({src_node, false}); } } @@ -1768,7 +1902,7 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( // The shapes of all the inputs to send_node are statically known. We // won't have to do any inference at compile time so return now: the // shapes were stored in static_shape_out above. - graphdef_out->reset(); + graph_out->reset(); return Status::OK(); } else { // Any shape that is being processed is either the original send node @@ -1791,8 +1925,37 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( } } - graphdef_out->reset(new GraphDef()); - graph_out->ToGraphDef(graphdef_out->get()); + for (const auto edge : back_edge_helper.RemovedEdges()) { + if (copied_node_images.find(edge.dst) != copied_node_images.end()) { + // The destination of this back edge was added to the inference graph, so + // fix it up. + Node* dst = copied_node_images[edge.dst]; + if (dst->type_string() != "Merge") { + return errors::InvalidArgument( + "outside_compilation cluster contains a back-edge to node ", + dst->name(), " of type ", dst->type_string(), + ". The analysis pass only supports back-edges to Merge nodes."); + } + const Edge* existing_input_edge; + if (edge.dst_input != 1 || dst->num_inputs() != 2 || + !dst->input_edge(0, &existing_input_edge).ok()) { + // TODO(misard) if we see graphs built with a different structure, relax + // this constraint. Leaving it here for now to avoid writing unnecessary + // complex code since we believe graphs generated by front ends all have + // the back edge as the second input to the merge node. + return errors::Internal( + "Internal assumption failed while rewriting an outside_compilation " + "cluster that contains a while loop. Logic assumes back-edge is to " + "port 1 of a 2-input " + "Merge node."); + } + // Connect the existing edge to both inputs of the Merge node so that the + // graph will be well-formed. + (*graph_out) + ->AddEdge(existing_input_edge->src(), + existing_input_edge->src_output(), dst, edge.dst_input); + } + } return Status::OK(); } @@ -1861,7 +2024,7 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( Status Encapsulator::MakeGraphForOutsideCompilationSends( const Graph& graph, std::unique_ptr* pruned_graph, - ShapeRefiner* shape_refiner, + BackEdgeHelper* back_edge_helper, ShapeRefiner* shape_refiner, std::unordered_map* node_images, FunctionLibraryDefinition* library) { // Find all the send_from_host nodes in all subgraphs, to use as roots for the @@ -1883,10 +2046,15 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends( // nodes, inlining any functions as needed. TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( graph, send_from_host_nodes, pruned_graph, node_images, library)); + FixupSourceAndSinkEdges(pruned_graph->get()); + + // Remove back edges from any cycles in the pruned graph to simplify shape + // inference traversal. They will be fixed up in the per-subgraph shape + // inference graphs stored in the function library. + TF_RETURN_IF_ERROR(back_edge_helper->Remove(pruned_graph->get())); // Perform shape inference on the pruned graph. shape_refiner->set_require_shape_inference_fns(false); - FixupSourceAndSinkEdges(pruned_graph->get()); std::vector post_order; GetReversePostOrder(*(*pruned_graph), &post_order); for (auto node : post_order) { @@ -1904,20 +2072,28 @@ Status Encapsulator::MakeGraphForOutsideCompilationSends( Status Encapsulator::GetShapeInfoForOutsideCompilationSends( Graph* graph_out, FunctionLibraryDefinition* library) { + BackEdgeHelper back_edge_helper; std::unique_ptr pruned_graph; ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); std::unordered_map node_images; TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( - *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + *graph_out, &pruned_graph, &back_edge_helper, &shape_refiner, + &node_images, library)); + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("pruned_graph_for_shape_inference", + *pruned_graph, library); + } for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; // Find all the recv_at_host nodes in this subgraph. std::vector outside_compilation_names; subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); std::unordered_set recv_at_host_names; - for (const auto& name : outside_compilation_names) { - Node* recv_node = subgraph.GetRecvAtHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(oc_name); if (recv_node != nullptr) { recv_at_host_names.insert(recv_node->name()); } @@ -1926,26 +2102,30 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( // without knowing the shape of the recv_at_host nodes, and store the // result, along with enough information to complete the job at compile time // once the recv_at_host shapes are known. - for (const auto& name : outside_compilation_names) { - Node* send_node = subgraph.GetSendFromHostNode(name); + for (const auto& oc_name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(oc_name); std::vector static_shape; - std::unique_ptr graphdef; + std::unique_ptr graph; if (send_node != nullptr) { TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( - *pruned_graph, shape_refiner, recv_at_host_names, - node_images[send_node], library, &static_shape, &graphdef)); - if (graphdef == nullptr) { + *pruned_graph, back_edge_helper, shape_refiner, recv_at_host_names, + node_images[send_node], library, &static_shape, &graph)); + if (graph == nullptr) { VLOG(2) << "Send node " << send_node->name() << " shapes"; for (int i = 0; i < static_shape.size(); ++i) { VLOG(2) << static_shape[i].DebugString(); } } else { - VLOG(2) << "Send node " << send_node->name() << " graph\n" - << graphdef->DebugString(); + if (VLOG_IS_ON(2)) { + GraphDef graphdef; + graph->ToGraphDef(&graphdef); + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef.DebugString(); + } } } - TF_RETURN_IF_ERROR( - subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + TF_RETURN_IF_ERROR(subgraph.AddShapeInferenceInfo( + subgraph_name, oc_name, static_shape, graph.get(), library)); } if (!outside_compilation_names.empty()) { TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index aed9cae0f1799c4524da8ee309344849798755d5..8599a7038af9663e5af6f3231429cb7f6ea5f69b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -13,22 +13,46 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { namespace { +const char* const kXlaHostTransferSequencerAttr = + "_xla_host_transfer_sequencer"; + +Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, + const string& name_suffix, + FunctionDefLibrary* library) { + GraphDef graphdef; + TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); + std::unique_ptr graph = + std::unique_ptr(new Graph(OpRegistry::Global())); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graphdef, graph.get())); + FunctionDef* fdef = library->add_function(); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *graph, + strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + fdef)); + return Status::OK(); +} + template bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const ::tensorflow::protobuf::Map& b, @@ -112,23 +136,7 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, a.attr(), b.attr(), [](const string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, [](const string& key, const AttrValue& av, const AttrValue& bv) { - if (key == "shape_inference_graph") { - // Default serialization of GraphDef is unstable because maps don't - // serialize deterministically. Rather than go through the hoops to - // turn on deterministic serialization of this attr just for this - // test, add logic here to compare determinstically. - GraphDef ga; - if (!ga.ParseFromString(av.s())) { - return false; - } - GraphDef gb; - if (!gb.ParseFromString(bv.s())) { - return false; - } - return EqualGraphDef(ga, gb, nullptr); - } else { - return av.DebugString() == bv.DebugString(); - } + return av.DebugString() == bv.DebugString(); }, strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); @@ -246,26 +254,32 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) -// TODO(misard): remove these fake registrations once there are real Ops to be -// compiled. -REGISTER_OP("_XlaHostCompute") +// These dummy Op registrations are here because the real Op registrations live +// in contrib and there can't be a dependence from this test to contrib. +REGISTER_OP("XlaHostCompute") .Input("inputs: Tinputs") .Output("outputs: Toutputs") .Attr("Tinputs: list(type) >= 0") .Attr("Toutputs: list(type) >= 0") .Attr("key: string") + .Attr("shape_inference_graph: string = ''") + .Attr("shapes: list(shape) >= 0") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaSendFromHost") - .Input("input: Tinputs") + .Input("inputs: Tinputs") + .Input("dynamic_key: string") .Attr("Tinputs: list(type) >= 0") .Attr("key: string") + .Attr("device_ordinal: int") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("_XlaRecvAtHost") - .Output("output: Toutputs") + .Input("dynamic_key: string") + .Output("outputs: Toutputs") .Attr("Toutputs: list(type) >= 0") .Attr("key: string") + .Attr("device_ordinal: int") .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("InputTest") @@ -315,8 +329,13 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); -Node* NoOp(const GraphDefBuilder::Options& opts) { - return ops::SourceOp("NoOp", opts); +Node* Sequencer(const GraphDefBuilder::Options& opts, + const string& call_node_name) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("NoOp"), "NoOp", + opts.op_registry()); + return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name) + .FinalizeBuilder(&node_builder); } Node* Input(const GraphDefBuilder::Options& opts) { @@ -327,43 +346,85 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShape(const gtl::ArraySlice& shape, - const GraphDefBuilder::Options& opts) { +Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", opts.op_registry()); TensorProto value; - value.set_dtype(DT_FLOAT); + value.set_dtype(dtype); for (int dim : shape) { value.mutable_tensor_shape()->add_dim()->set_size(dim); } return opts.WithAttr("value", value) - .WithAttr("dtype", DT_FLOAT) + .WithAttr("dtype", dtype) .FinalizeBuilder(&node_builder); } -Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, +Node* KnownShape(const gtl::ArraySlice& shape, const GraphDefBuilder::Options& opts) { + return KnownShapeBase(DT_FLOAT, shape, opts); +} + +Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) { + return KnownShapeBase(DT_STRING, {2}, opts); +} + +Node* KeyPlaceholder(const string& call_node, + const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), + NodeBuilder node_builder(opts.GetNameForOp("Placeholder"), "Placeholder", + opts.op_registry()); + TensorShapeProto shape; + shape.add_dim()->set_size(2); + return opts.WithAttr("shape", shape) + .WithAttr("dtype", DT_STRING) + .WithAttr("_host_compute_call_node", call_node) + .FinalizeBuilder(&node_builder); +} + +Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, + const string& oc_cluster, + const gtl::ArraySlice& dtypes, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + string key = + strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = strings::StrCat("outside_compilation_", cluster, "_", + oc_cluster, "_recv"); + NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); + node_builder.Input(std::move(key_input)); return opts.WithAttr("Toutputs", dtypes) .WithAttr("key", key) + .WithAttr("device_ordinal", 0) + .WithAttr("_encapsulate", cluster) + .WithAttr("_outside", oc_cluster) .FinalizeBuilder(&node_builder); } -Node* SendFromHost(const string& key, const std::vector& inputs, +Node* SendFromHost(ops::NodeOut key_input, const string& cluster, + const string& oc_cluster, + const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), + string key = + strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = strings::StrCat("outside_compilation_", cluster, "_", + oc_cluster, "_send"); + NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); + node_builder.Input(std::move(key_input)); std::vector dtypes; for (const auto& node : inputs) { dtypes.push_back(node.dt); } - return opts.WithAttr("key", key) - .WithAttr("Tinputs", dtypes) + return opts.WithAttr("Tinputs", dtypes) + .WithAttr("key", key) + .WithAttr("device_ordinal", 0) + .WithAttr("_encapsulate", cluster) + .WithAttr("_outside", oc_cluster) .FinalizeBuilder(&node_builder); } @@ -711,7 +772,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - StringPiece(n->name()).starts_with("const")) { + str_util::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -756,7 +817,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - StringPiece(n->name()).starts_with("const")) { + str_util::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -806,19 +867,20 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -833,13 +895,15 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -851,24 +915,29 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + b2.opts() + .WithName("E") + .WithControlInputs({recv, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = + b2.opts().WithControlInputs({s}).FinalizeBuilder(&node_builder); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -918,38 +987,43 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected_1; { GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape1.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape1.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape1.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape1_graph; - TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); - EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); } - string shape_string_expected_2; { GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - shape2.opts().WithName("E")); - Node* recv2 = - RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, - shape2.opts().WithName("outside_compilation_F1_O2_recv")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); - SendFromHost("host_compute_channel_F1_O2", {h}, - shape2.opts().WithName("outside_compilation_F1_O2_send")); - GraphDef shape2_graph; - TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); - EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + shape2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT, DT_FLOAT}, shape2.opts()); + Node* h = Binary(ops::NodeOut(recv2, 0), e, + shape2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, shape2.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); } *library_expected.add_function() = FunctionDefHelper::Create( @@ -966,22 +1040,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0", "F:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O2"}, - {"shape_inference_graph", shape_string_expected_2}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O2"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}, {"F"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected_1}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"i_0_retval", "I:o:0"}}); @@ -993,35 +1071,45 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - NodeBuilder node_builder("F1", "F1", lib_def.get()); - node_builder.Input(a).Input(b); - Node* call = b2.opts().FinalizeBuilder(&node_builder); - - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); - - Node* recv2 = - RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O2_recv")); + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* g = Binary(e, ops::NodeOut(recv2, 1), - b2.opts().WithName("G").WithControlInputs({recv2, e})); - Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); + b2.opts() + .WithName("G") + .WithControlInputs({recv2, e}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* h = Binary(ops::NodeOut(recv2, 0), e, + b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); Node* send2 = - SendFromHost("host_compute_channel_F1_O2", {h}, - b2.opts().WithName("outside_compilation_F1_O2_send")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {h}, b2.opts()); - Node* s = NoOp(b2.opts() - .WithName("F1_sequencer") - .WithControlInputs({recv1, send1, recv2, send2})); + Node* s = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2}), + "F1"); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().WithControlInput(s).FinalizeBuilder(&node_builder); - Binary(g, call, b2.opts().WithName("J").WithControlInput(s)); + Binary(g, call, b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1070,19 +1158,20 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), - shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } TensorShapeProto shape_proto_expected; @@ -1100,13 +1189,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); @@ -1120,14 +1211,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { "BinaryTest", {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"G:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1138,39 +1230,48 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = InputShaped(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* key_constant1 = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), - b2.opts().WithName("E").WithControlInputs({recv1, b})); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Node* recv2 = - RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* key_constant2 = + KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", + {DT_FLOAT}, b2.opts()); Node* h = Binary(ops::NodeOut(call1, 1), recv2, - b2.opts().WithName("H").WithControlInput(s1)); - Node* send2 = - SendFromHost("host_compute_channel_F2_O1", {h}, - b2.opts().WithName("outside_compilation_F2_O1_send")); + b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts()); + Node* s2 = Sequencer( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), + "F2"); NodeBuilder node_builder2("F2", "F2", lib_def.get()); node_builder2.Input(e).Input(call1); Node* call2 = b2.opts() - .WithControlInputs({s1, e, call1}) + .WithControlInputs({s2, e, call1}) .FinalizeBuilder(&node_builder2); - Node* s2 = NoOp( - b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); - Binary(call2, ops::NodeOut(call2, 1), - b2.opts().WithName("J").WithControlInput(s2)); + Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1218,14 +1319,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1236,16 +1338,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts().WithName("E")); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* send1 = - SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(send1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1294,14 +1402,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { "BinaryTest", {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}}, + gtl::ArraySlice({shape_proto_expected})}, + {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1313,20 +1422,26 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { Node* a = InputShaped(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); + RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, b2.opts()); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithControlInput(recv1) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); Node* send1 = - SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts().WithName("outside_compilation_F1_O1_send")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + Unary(call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1368,13 +1483,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, {{"F"}, "UnaryTest", {"D:o:0"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}}}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1385,16 +1501,22 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInput(recv1), "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1441,13 +1563,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}}}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1458,21 +1581,25 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* recv1 = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = Unary(recv1, b2.opts().WithName("E")); - Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); - Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); - Node* s1 = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); - Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + Binary(e, call1, b2.opts().WithName("G")); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } @@ -1523,7 +1650,10 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { Node* a = Input(b2.opts().WithName("A")); Node* b = Input(b2.opts().WithName("B")); - Node* e = Unary(a, b2.opts().WithName("E")); + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); NodeBuilder node_builder1("F1", "F1", lib_def.get()); node_builder1.Input(a).Input(b); Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); @@ -1569,19 +1699,21 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { FunctionDefLibrary library_expected; GraphDef graphdef_expected; - string shape_string_expected; { GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); - Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - shape.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); - SendFromHost("host_compute_channel_F1_O1", {e}, - shape.opts().WithName("outside_compilation_F1_O1_send")); - GraphDef shape_graph; - TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); - EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_1")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape.opts()); + Node* e = BinaryUnknownShape(known, recv, + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); } *library_expected.add_function() = test::function::XTimesTwo(); @@ -1595,13 +1727,15 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {}, {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, - "_XlaHostCompute", + "XlaHostCompute", {"c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"key", "host_compute_channel_F1_O1"}, - {"shape_inference_graph", shape_string_expected}, - {"shapes", gtl::ArraySlice({})}}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1614,26 +1748,29 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { Node* b = Input(b2.opts().WithName("B")); Node* c = Unary(a, b2.opts().WithName("C")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = BinaryUnknownShape(c, ops::NodeOut(recv, 0), + b2.opts() + .WithName("E") + .WithControlInputs({recv, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + + Node* s = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder("F1", "F1", lib_def.get()); node_builder.Input(b).Input(c); Node* call = - b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); - - Node* recv = - RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, - b2.opts().WithName("outside_compilation_F1_O1_recv")); - Node* e = BinaryUnknownShape( - c, ops::NodeOut(recv, 0), - b2.opts().WithName("E").WithControlInputs({recv, b})); - Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, - b2.opts() - .WithName("outside_compilation_F1_O1_send") - .WithControlInput(e)); - - Node* s = NoOp( - b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); - - Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + b2.opts().WithControlInputs({s, c}).FinalizeBuilder(&node_builder); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({e})); TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index 6fa21fa6204dcc9446081d07e2a59ccace216713..8f5e11dfa47956f1fdaa4d1ff115affa375c5c73 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -229,7 +230,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) { NodeDef* node_def = fdef->mutable_node_def(n_index); for (int i = 0; i < node_def->input_size(); ++i) { - if (StringPiece(node_def->input(i)).starts_with("^")) { + if (str_util::StartsWith(node_def->input(i), "^")) { // Control input const string normalized = node_names.Renormalize(node_def->input(i).substr(1)); diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 15507b3851751c681044a744c07c247410fb3e2d..676f71a75aede2a7720ae0c8a579d64cc184509a 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -27,17 +27,3 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 9bea5663319c8a25249fdc265cee0191556a7c04..00a6f4075f9a18efc3895b033eb6d08e36088a53 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -14,6 +14,7 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_device", + "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", @@ -40,17 +41,3 @@ cc_library( ], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 6353149e4afdf739fe44dd5c76502ef5d98b8477..f48941fce329313e4484b3c2dd900eeac884ed34 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -40,111 +41,6 @@ namespace gpu = perftools::gputools; namespace tensorflow { -// Adapter class that wraps a Tensorflow allocator as an XLA allocator. -// Assumes that the Tensorflow allocator permits asynchronous deallocation: -// see comment on `AllowsAsynchronousDeallocation()`. -class XlaAllocator : public xla::DeviceMemoryAllocator { - public: - XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context); - ~XlaAllocator() override; - xla::StatusOr Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) override; - Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override; - - // Register an Tensor (input or resource variable) with the allocator. If - // the operation returns an alias to one of its inputs, then the allocator - // needs to be able to handle it. - Status RegisterArgument(const Tensor* t); - - // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is - // interpreted as having data type 'dtype' and shape 'shape'. - Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype, - const TensorShape& shape, Tensor* tensor) const; - - // The Tensorflow BFC allocator used on GPU allows host-side deallocation - // before GPU execution takes place. Tensorflow uses the ordering of the main - // compute stream to enforce a happens-before relationship between a memory - // allocation and code that reuses the same memory. If Tensorflow adds - // support for multiple GPU streams or allocators with different ordering - // requirements, this code may need to change. - // (This attribute has no effect on CPU.) - bool AllowsAsynchronousDeallocation() const override { return true; } - - private: - OpKernelContext* const op_context_; - - // Map from pointer address to the owning Tensor; used by - // MakeTensorFromBuffer. Also used to automatically release Tensors when the - // allocator is freed. - std::unordered_map tensors_; -}; - -XlaAllocator::XlaAllocator(const gpu::Platform* platform, - OpKernelContext* op_context) - : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} - -XlaAllocator::~XlaAllocator() = default; - -xla::StatusOr XlaAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { - AllocatorAttributes allocator_attrs; - allocator_attrs.set_on_host(false); - - AllocationAttributes allocation_attrs; - allocation_attrs.no_retry_on_failure = !retry_on_failure; - - Tensor t; - Status status = op_context_->allocate_temp( - DT_UINT8, TensorShape({static_cast(size)}), &t, allocator_attrs, - allocation_attrs); - if (!status.ok()) { - VLOG(2) << "Allocation failed " << size; - return status; - } - void* data = - reinterpret_cast(const_cast(t.tensor_data().data())); - tensors_[data] = t; - return gpu::DeviceMemoryBase(data, size); -} - -Status XlaAllocator::RegisterArgument(const Tensor* t) { - void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); - tensors_[data] = *t; - return Status::OK(); -} - -Status XlaAllocator::Deallocate(int device_ordinal, - gpu::DeviceMemoryBase* mem) { - if (mem->opaque() != nullptr) { - if (tensors_.erase(mem->opaque()) == 0) { - return tensorflow::errors::InvalidArgument("Unknown tensor address"); - } - } - return Status::OK(); -} - -Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, - DataType dtype, - const TensorShape& shape, - Tensor* out_tensor) const { - void* ptr = const_cast(buffer.opaque()); - auto it = tensors_.find(ptr); - if (it == tensors_.end()) { - return errors::InvalidArgument("Unknown tensor address"); - } - const Tensor& tensor = it->second; - - int64 output_size = DataTypeSize(dtype) * shape.num_elements(); - if (tensor.TotalBytes() == output_size) { - out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape); - } else { - Tensor slice = tensor.Slice(0, output_size); - out_tensor->UnsafeCopyFromInternal(slice, dtype, shape); - } - return Status::OK(); -} - XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) : OpKernel(ctx), device_type_(ctx->device_type()) { const NameAttrList* func; @@ -196,23 +92,6 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -std::vector SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { - std::vector snapshot(num_variables); - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { - Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - if (LookupResource(ctx, handle, &variable).ok()) { - tf_shared_lock lock(*variable->mu()); - snapshot[i].name = handle.name(); - snapshot[i].present = true; - snapshot[i].value = *variable->tensor(); - } - } - return snapshot; -} - void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOp::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); @@ -235,22 +114,39 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); + const XlaDevice::Metadata* metadata; + Status s = XlaDevice::GetMetadata(ctx, &metadata); + bool allocate_xla_tensors = s.ok(); + // Get the platform_id_ for XLA_* devices. if (platform_id_ == nullptr) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { platform_id_ = metadata->platform()->id(); } } - std::vector variables = + std::map variables = SnapshotResourceVariables(ctx, num_resource_args_); xla::LocalClient* client = static_cast(cache->client()); - // Builds an XLA allocator for the device. - XlaAllocator xla_allocator(client->platform(), ctx); + XlaAllocator local_xla_allocator(client->backend().platform(), + ctx->device()->GetAllocator({})); + xla::DeviceMemoryAllocator* xla_allocator; + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which local_xla_allocator above uses) as on an XlaDevice, this is a + // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a + // real allocator to allocate real buffers. + if (allocate_xla_tensors) { + xla_allocator = client->backend().memory_allocator(); + } else { + xla_allocator = &local_xla_allocator; + } XlaCompiler::Options options; options.client = client; @@ -258,150 +154,45 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); - options.device_allocator = &xla_allocator; + options.device_allocator = xla_allocator; + // TODO(b/77671268): We don't set variable_representation_shape_fn here. This + // is restricted to Variables, but we need something like this to apply to + // normal Tensors too. const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, + std::map constant_args; + for (int i = 0; i < num_constant_args_; ++i) { + constant_args.insert({i, ctx->input(i)}); + } + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, variables, ctx, &kernel, &executable, /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; - std::unique_ptr output; - // Build xla::ShapedBuffers that point directly to the Tensor buffers. - std::vector> arg_buffers; - arg_buffers.reserve(kernel->xla_input_shapes.size() + 1); - arg_buffers.resize(kernel->xla_input_shapes.size()); - std::vector arg_ptrs(arg_buffers.size()); - - const int first_variable_arg = ctx->num_inputs() - num_resource_args_; - // Pass remaining parameters. - const Tensor* t; - for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { - int arg_num = kernel->input_mapping[i]; - const xla::Shape& shape = kernel->xla_input_shapes[i]; - if (arg_num >= first_variable_arg) { - t = &(variables[arg_num - first_variable_arg].value); - } else { - t = &(ctx->input(arg_num)); - } - - gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( - const_cast(t->tensor_data().data()), t->tensor_data().size()); - - const xla::Shape on_device_shape = - client->backend().transfer_manager()->HostShapeToDeviceShape(shape); - CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) - << "On-device shape " - << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) - << " not the same as on-host shape " - << xla::ShapeUtil::HumanStringWithLayout(shape); - arg_buffers[i] = xla::MakeUnique( - /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), - client->default_device_ordinal()); - arg_buffers[i]->set_buffer(dmem, /*index=*/{}); - arg_ptrs[i] = arg_buffers[i].get(); - - OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); - } + XlaComputationLaunchContext launch_context( + num_resource_args_, client, xla_allocator, allocate_xla_tensors); + launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); - run_options.set_allocator(&xla_allocator); + run_options.set_allocator(xla_allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(ctx->step_id()); Env* env = Env::Default(); auto start_time = env->NowMicros(); - auto run_result = executable->Run(arg_ptrs, run_options); + + auto run_result = executable->Run(launch_context.arguments(), run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - // Computation output should always be a tuple. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); - } - CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); - - // Copy XLA results to the OpOutputList. - int output_num = 0; - for (int i = 0; i < ctx->num_outputs(); ++i) { - if (kernel->outputs[i].is_constant) { - // Output is a constant. - const Tensor& const_tensor = kernel->outputs[i].constant_value; - const size_t total_bytes = const_tensor.TotalBytes(); - if (stream && total_bytes > 0) { - // Copy host -> device. (Empty tensors don't have backing buffers.) - VLOG(1) << "Constant output tensor on device"; - Tensor* output_tensor; - TF_CHECK_OK( - ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - - const void* src_ptr = DMAHelper::base(&const_tensor); - void* dst_ptr = DMAHelper::base(output_tensor); - gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes); - stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes); - } else { - // No copy required. - ctx->set_output(i, const_tensor); - } - } else { - const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - Tensor output_tensor; - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer( - buffer, ctx->expected_output_dtype(i), shape, - &output_tensor)); - ctx->set_output(i, output_tensor); - ++output_num; - } - - if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DebugString(); - } - } - - // Apply variable updates, if any. - VLOG(2) << "Applying variable updates"; - for (int i = 0; i < kernel->resource_updates.size(); ++i) { - const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); - - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); - - Var* variable = nullptr; - // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not - // a Tensor. - OP_REQUIRES_OK(ctx, LookupOrCreateResource( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); - - core::ScopedUnref s(variable); - - mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); - - // Looks up the owning Tensor by buffer address. - OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, - variable->tensor())); - ++output_num; - } - + launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 47fd912b12abbbe876e933ab57f6f586fd299909..c6cc0986af0300c51283d432c671e92a1e4d8145 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,14 +26,6 @@ limitations under the License. namespace tensorflow { -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back -// resource variables since concurrent updates may modify the shape, and it is -// important that the shapes used for compilation match the true shapes of the -// buffers. -std::vector SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); - // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 4491dd6ac8f2b84f341162eb469cc8194f817c9a..5d211f4d733d8d807426e62dd116092799184f35 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -52,16 +52,14 @@ cc_library( ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", +cc_library( + name = "xla_device_flags", + srcs = ["xla_device_flags.cc"], + hdrs = ["xla_device_flags.h"], + deps = + [ + "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", ], - ), - visibility = ["//tensorflow:__subpackages__"], ) diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 4bc209b7ecf499d82e7567f7eff12b17cefa9863..7277a1d1f8ad5fa045645ead839ab9efa01e89c7 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -40,6 +40,8 @@ static void AllocateFlags() { flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; flags->tf_xla_cpu_global_jit = false; + flags->tf_xla_clustering_fuel = std::numeric_limits::max(); + flags->tf_xla_fusion_only = false; flag_list = new std::vector( {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, "Control compilation of operators into XLA computations on CPU and " @@ -55,7 +57,13 @@ static void AllocateFlags() { Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, "Dump graphs during XLA compilation."), Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, - "Enables global JIT compilation for CPU via SessionOptions.")}); + "Enables global JIT compilation for CPU via SessionOptions."), + Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel, + "Places an artificial limit on the number of ops marked as " + "eligible for clustering."), + Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only, + "enable fusion of element-wise operations only using XLA when " + "global_jit_level is ON*.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index e1ccd7ddb8706ca445b6811ca1fec369af7cd5d5..2affda6ab4e0fbad32a246744fa5b38aeb629c1b 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -48,6 +48,13 @@ typedef struct { bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU // via SessionOptions. + int64 tf_xla_clustering_fuel; // "Compiler fuel" for clustering. Only this + // many ops will be marked as eligible for + // clustering. + bool tf_xla_fusion_only; // This flag is effective only when global_jit_level + // is set to ON* and overrides its behavior. If + // true, enable fusion of element-wise operations + // only using XLA. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..1bb2fce2dbad5bffce2e33b665b7222090d0855a --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.cc @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legacy flags for the XLA bridge's xla_device module. + +#include +#include + +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// Pointers to the parsed value of the flags and flag descriptors, initialized +// via flags_init. +static XlaDeviceFlags* flags; +static std::vector* flag_list; +static std::once_flag flags_init; + +// Allocate *flags. Called via call_once(&flags_init,...). +static void AllocateFlags() { + flags = new XlaDeviceFlags; + flags->tf_xla_compile_on_demand = false; + flag_list = new std::vector({ + Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand, + "Switch a device into 'on-demand' mode, where instead of " + "autoclustering ops are compiled one by one just-in-time."), + }); + xla::legacy_flags::ParseFlagsFromEnv(*flag_list); +} + +// Return a pointer to the XlaDeviceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +XlaDeviceFlags* GetXlaDeviceFlags() { + std::call_once(flags_init, &AllocateFlags); + return flags; +} + +} // namespace legacy_flags +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..27b22121ac1e089bd5d5a494e1e3fb60b05bc76d --- /dev/null +++ b/tensorflow/compiler/jit/legacy_flags/xla_device_flags.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ + +// Legacy flags for the XLA bridge's xla_device module. + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace legacy_flags { + +// The values of flags associated with the XLA bridge's +// xla_device module. +typedef struct { + // Switch the CPU device into "on-demand" mode, where instead of + // autoclustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; +} XlaDeviceFlags; + +// Return a pointer to the XlaDeviceFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. +XlaDeviceFlags* GetXlaDeviceFlags(); + +} // namespace legacy_flags +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index a0211acbbe9eec77d30c7d14293650de8826f41c..8e2ee0f1d71bc17b4c12c792c38002af4f9eb5eb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/version.h" @@ -50,6 +51,15 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // is really a kind of function call and will be handled by // IsCompilableCall(). if (node.type_string() == "SymbolicGradient") return false; + if (node.type_string() == "Const") { + // Skip Const op with type DT_STRING, since XLA doesn't support it, but the + // registered Const KernelDef says that it does, to support no-op Assert for + // tfcompile. + const AttrValue* attr = node.attrs().Find("dtype"); + if (attr != nullptr && attr->type() == DT_STRING) { + return false; + } + } return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } @@ -174,10 +184,164 @@ bool HasResourceInputOrOutput(const Node& node) { } struct NodeCompare { - bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } + bool operator()(const Node* a, const Node* b) const { + return a->id() < b->id(); + } }; using OrderedNodeSet = std::set; +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +// +// TODO(hpucha): Consider a black list instead of a white list as +// implemented below. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/batchtospace_op.cc + "BatchToSpace", "BatchToSpaceND", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/categorical_op.cc + "Multinomial" /* (Rng ops are disabled on GPU backend currently)*/, + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/cross_op.cc + "Cross", + // tf2xla/kernels/depthtospace_op.cc + "DepthToSpace", + // tf2xla/kernels/diag_op.cc + "Diag", "DiagPart", "MatrixDiag", "MatrixDiagPart", + // tf2xla/kernels/dynamic_stitch_op.cc + "DynamicStitch", "ParallelDynamicStitch", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fake_quantize_ops.cc + "FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxArgsGradient", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxVarsGradient" /*(Reduce)*/, + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/gather_op.cc + "Gather", "GatherV2", "GatherNd", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", "StopGradient", + "Snapshot", + // tf2xla/kernels/image_ops.cc + "RGBToHSV", "HSVToRGB", "AdjustContrastv2" /*(Reduce)*/, + "AdjustSaturation", "AdjustHue", + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/l2loss_op.cc + "L2Loss" /*(Reduce)*/, + // tf2xla/kernels/lrn_ops.cc (ReduceWindow) + "LRN", "LRNGrad", + // tf2xla/kernels/matrix_band_part_op.cc + "MatrixBandPart", + // tf2xla/kernels/matrix_set_diag_op.cc + "MatrixSetDiag", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/no_op.cc + "NoOp", "ControlTrigger", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/pooling_ops.cc + "MaxPool", "MaxPoolV2", "MaxPool3D", "AvgPool", + "AvgPool3D", /*(all the pooling ops use ReduceWindow)*/ + "MaxPoolGrad", "MaxPoolGradV2", "MaxPool3DGrad", "AvgPoolGrad", + "AvgPool3DGrad", + // tf2xla/kernels/quantize_and_dequantize_op.cc (Reduce) + "QuantizeAndDequantizeV2", + // tf2xla/kernels/random_ops.cc (Rng ops are disabled on GPU backend + // currently) + "RandomUniform", "RandomUniformInt", "RandomStandardNormal", + "TruncatedNormal", + // tf2xla/kernels/reduction_ops.cc (Reduce) + "Sum", "Prod", "Min", "Max", "Mean", "All", "Any", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/scan_ops.cc (ReduceWindow) + "Cumsum", "Cumprod", + // tf2xla/kernels/scatter_nd_op.cc (Reduce) + "ScatterNd", + // tf2xla/kernels/segment_reduction_ops.cc (Reduce) + "UnsortedSegmentSum", + // tf2xla/kernels/select_op.cc + "Select", + // tf2xla/kernels/sequence_ops.cc + "Range", "LinSpace", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/softmax_op.cc (Reduce) + "Softmax", "LogSoftmax", "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + // tf2xla/kernels/spacetobatch_op.cc + "SpaceToBatchND", "SpaceToBatch", + // tf2xla/kernels/spacetodepth_op.cc + "SpaceToDepth", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/stack_ops.cc + "StackV2", "StackPushV2", "StackPopV2", "StackCloseV2", + // tf2xla/kernels/stateless_random_ops.cc (Rng ops are disabled on + // GPU + // backend currently) + "StatelessRandomUniform", + "StatelessRandomNormal" + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", + "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/training_ops.cc + "ResourceApplyGradientDescent", "ResourceApplyMomentum", + "ResourceApplyAdagrad", "ResourceApplyAdam", "ResourceApplyRMSProp", + "ResourceApplyFtrl", "ResourceApplyFtrlV2", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, @@ -189,7 +353,27 @@ Status FindCompilationCandidates( FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + int64& fuel = + legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; + + // Iterate over nodes in sorted order so that compiler fuel is deterministic. + // We can't simply pass op_nodes().begin() and op_nodes().end to the + // std::vector constructor because they're not proper iterators, with + // iterator_traits defined and so on. + std::vector sorted_nodes; for (Node* node : graph.op_nodes()) { + sorted_nodes.push_back(node); + } + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + + for (Node* node : sorted_nodes) { + VLOG(2) << "Fuel: " << fuel; + if (fuel <= 0) { + VLOG(2) + << "Hit fuel limit; not marking any remaining ops as clusterable."; + break; + } + VLOG(2) << "FindCompilationCandidates(): Processing " << node->DebugString(); @@ -234,7 +418,9 @@ Status FindCompilationCandidates( continue; } candidates->insert(node); + --fuel; } + VLOG(2) << "candidates->size() = " << candidates->size(); return Status::OK(); } @@ -256,6 +442,9 @@ string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, } auto node_name = [&cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } auto* node = graph.FindNodeId(node_id); if (node == nullptr) { return string("(null)"); @@ -314,10 +503,13 @@ Status MarkForCompilationPass::Run( static_cast(flags->tf_xla_auto_jit); } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + bool fusion_only = flags->tf_xla_fusion_only; + VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; + VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld]( const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), @@ -340,6 +532,11 @@ Status MarkForCompilationPass::Run( status = fld->GetAttr(*node, kXlaCompileAttr, &compile); if (status.ok()) return compile; + // Check for fusable ops only if requested. + if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + return false; + } + // Otherwise use the value of global_jit_level. // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag @@ -544,11 +741,15 @@ Status MarkForCompilationPass::RunImpl( } } - // Count the number of elements in each cluster. - std::vector cluster_sizes(graph->num_node_ids()); + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph->num_node_ids()); for (const Node* n : compilation_candidates) { int cluster = clusters[n->id()].Get().representative; - cluster_sizes[cluster]++; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } } // Names for each cluster. @@ -581,9 +782,12 @@ Status MarkForCompilationPass::RunImpl( const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); - // Or compile if this is a cluster of >= min_cluster_size compilable - // operators. - if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || + // Compile if this is a cluster of >= min_cluster_size compilable operators. + // Also, always compile if the operator is placed on a device that requires + // compilation, or if it contains at least one op that is marked for + // compilation that is not an Identity op. + if (effective_cluster_sizes[cluster] >= min_cluster_size || + (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) || registration->requires_compilation) { string& name = cluster_names[cluster]; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 1a8858cccef623185709ab5dc2187a313dd130f7..703d8825d74ced8d4d69c31ccd730adc89a8bffe 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,6 +29,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -137,7 +140,7 @@ TEST(XlaCompilationTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } -TEST(XlaCompilationTest, UnsupportedTypes) { +TEST(XlaCompilationTest, Complex128Unsupported) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; { @@ -157,6 +160,27 @@ TEST(XlaCompilationTest, UnsupportedTypes) { EXPECT_TRUE(clusters.empty()); } +TEST(XlaCompilationTest, HalfSupported) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Tensor t(DT_HALF, TensorShape()); + t.scalar()() = static_cast(0.0f); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_HALF) + .WithAttr("value", t)); + Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); + ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_FALSE(clusters.empty()); +} + TEST(XlaCompilationTest, ConcatWithConstArg) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -519,11 +543,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.ToString()) - .contains("Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(str_util::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { @@ -553,5 +577,61 @@ TEST(XlaCompilationTest, Retval) { EXPECT_EQ(clusters["A"], clusters["B"]); } +TEST(XlaCompilationTest, DontCountIdentityOps) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + { + auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); + auto b = ops::Identity(root.WithOpName("B"), a); + auto c = ops::Identity(root.WithOpName("C"), b); + auto r = ops::_Retval(root.WithOpName("R"), c, 0); + } + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_TRUE(clusters.empty()); +} + +TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + { + auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); + auto b = ops::Identity(root.WithOpName("B"), a); + b.node()->AddAttr(kXlaCompileAttr, true); + auto r = ops::_Retval(root.WithOpName("R"), b, 0); + } + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_TRUE(clusters.empty()); +} + +TEST(XlaCompilationTest, ConstOp) { + // valid data type + { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + auto c = ops::Const(root.WithOpName("const"), 0.5f); + c.node()->AddAttr(kXlaCompileAttr, true); + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + EXPECT_EQ(1, GetClusters(*graph).size()); + } + + // invalid data type + { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + Scope root = Scope::NewRootScope().ExitOnError(); + auto c = ops::Const(root.WithOpName("const"), string("string")); + c.node()->AddAttr(kXlaCompileAttr, true); + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + EXPECT_TRUE(GetClusters(*graph).empty()); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index e5787ca4c8cff436e4404b8488970248b24a5eda..c9e46bc1475aed0e35a48765ad70eef4362e8281 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -17,17 +17,3 @@ cc_library( deps = ["//tensorflow/core:framework"], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/jit/producer_consumer_queue.h b/tensorflow/compiler/jit/producer_consumer_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..7c8c04152d2f3a0fd46711df24756b7e68b967ea --- /dev/null +++ b/tensorflow/compiler/jit/producer_consumer_queue.h @@ -0,0 +1,132 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ +#define TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ + +#include +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// A thread-safe, first-in-first-out queue. +template +class ProducerConsumerQueue { + public: + ProducerConsumerQueue() + : capacity_(std::numeric_limits::max()) {} + ~ProducerConsumerQueue() = default; + + // Wait until the queue is non-full, then append a copy of v. + void Put(const T &v); + + // Wait until the queue is non-empty, then remove and return the head value. + T Get(); + + // If the queue is non-empty, remove the head value, placing it in *pv, and + // return true; otherwise return false. + bool TryGet(T *pv); + + // Set the capacity of the queue; the queue is full whenever count() >= + // capacity(). The initial value is the maximum size_t. Requires size > 0. + void set_capacity(std::size_t size); + + // Return the capacity of the queue. + std::size_t capacity() const; + + // Return the number of elements in the queue. + std::size_t count() const; + + // Implementation details follow. Clients should ignore. + private: + mutable tensorflow::mutex mu_; // protects all fields below + tensorflow::condition_variable non_empty_ GUARDED_BY(mu_); + tensorflow::condition_variable non_full_ GUARDED_BY(mu_); + std::size_t capacity_ GUARDED_BY(mu_); + std::deque queue_ GUARDED_BY(mu_); + + TF_DISALLOW_COPY_AND_ASSIGN(ProducerConsumerQueue); +}; + +// ------------------------------------------------------ +// Implementation details follow. Clients should ignore. + +// Wait until the queue is non-full, then append a copy of v. +template +void ProducerConsumerQueue::Put(const T &v) { + mutex_lock lock(mu_); + while (queue_.size() >= capacity_) { + non_full_.wait(lock); + } + queue_.push_back(v); + non_empty_.notify_one(); +} + +// Wait until the queue is non-empty, then remove and return the head value. +template +T ProducerConsumerQueue::Get() { + mutex_lock lock(mu_); + while (queue_.empty()) { + non_empty_.wait(lock); + } + non_full_.notify_one(); + T result_value = queue_.front(); + queue_.pop_front(); + return result_value; +} + +// If the queue is non-empty, remove the head value, placing it in *pv, and +// return true; otherwise return false. +template +bool ProducerConsumerQueue::TryGet(T *pv) { + mutex_lock lock(mu_); + bool got_element = !queue_.empty(); + if (got_element) { + non_full_.notify_one(); + *pv = queue_.front(); + queue_.pop_front(); + } + return got_element; +} + +// Set the capacity of the queue; the queue is full whenever count() >= +// capacity(). The initial value is the maximum size_t. Requires size > 0. +template +void ProducerConsumerQueue::set_capacity(std::size_t size) { + mutex_lock lock(mu_); + CHECK_NE(size, 0); + capacity_ = size; + non_full_.notify_all(); +} + +// Return the capacity of the queue. +template +std::size_t ProducerConsumerQueue::capacity() const { + mutex_lock lock(mu_); + std::size_t max_elements = capacity_; + return max_elements; +} + +// Return the number of elements in the queue. +template +std::size_t ProducerConsumerQueue::count() const { + mutex_lock lock(mu_); + std::size_t num_elements = queue_.size(); + return num_elements; +} +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PRODUCER_CONSUMER_QUEUE_H_ diff --git a/tensorflow/compiler/jit/producer_consumer_queue_test.cc b/tensorflow/compiler/jit/producer_consumer_queue_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f61260c6e52756ee039829afdc7452f5f760c221 --- /dev/null +++ b/tensorflow/compiler/jit/producer_consumer_queue_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/producer_consumer_queue.h" + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +typedef ProducerConsumerQueue IntQueue; + +// Insert integers between low inclusive and high exclusive into q. +void PushRange(IntQueue *q, int low, int high) { + while (low != high) { + q->Put(low); + VLOG(2) << "Pushing " << low; + ++low; + } +} + +// Push the numbers between 0 and 999 inclusive from several threads in the +// pool. +void PushRanges(IntQueue *queue, thread::ThreadPool *pool) { + VLOG(1) << "Adding 20-36"; + pool->Schedule([queue] { PushRange(queue, 20, 36); }); + VLOG(1) << "Adding 7-20"; + pool->Schedule([queue] { PushRange(queue, 7, 20); }); + VLOG(1) << "Adding 36-501"; + pool->Schedule([queue] { PushRange(queue, 36, 501); }); + VLOG(1) << "Adding 501-1000"; + pool->Schedule([queue] { PushRange(queue, 501, 1000); }); + VLOG(1) << "Adding 0-5"; + pool->Schedule([queue] { PushRange(queue, 0, 5); }); + VLOG(1) << "Adding 5-7"; + pool->Schedule([queue] { PushRange(queue, 5, 7); }); +} + +// Pop elements from queue using Get(). Make sure that exactly elements +// were present and their values are all integers between 0 and high-1 +// inclusive. +void GetRange(IntQueue *queue, int high) { + VLOG(1) << "Testing Wait"; + std::vector results; + for (int i = 0; i != high; ++i) { + int r = queue->Get(); + VLOG(2) << "Waited and got " << r; + results.push_back(r); + } + CHECK_EQ(queue->count(), 0); + std::sort(results.begin(), results.end()); + for (int i = 0; i != high; ++i) { + CHECK(results[i] == i); + } +} + +// Pop elements from queue using TryGet(). Make sure that exactly +// elements were present and their values are all integers between 0 and high-1 +// inclusive. +void TryGetRange(IntQueue *queue, int high) { + std::vector results; + // Give up if we don't get all the elements back from the queue + // in 10 seconds. + int timeout = 10; + int r; + for (int i = 0; i != high; ++i) { + while (!queue->TryGet(&r)) { + if (!timeout--) { + LOG(FATAL) << "Can't find all elements in the queue"; + } + VLOG(1) << "Sleeping for a second..."; + sleep(1); + } + VLOG(2) << "Popped " << r; + results.push_back(r); + } + CHECK_EQ(queue->count(), 0); + CHECK(!queue->TryGet(&r)); + std::sort(results.begin(), results.end()); + for (int i = 0; i != high; ++i) { + CHECK_EQ(i, results[i]); + } +} + +const int kNumThreads = 15; + +TEST(ProducerConsumerQueue, GetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + PushRanges(&queue, &pool); + } + GetRange(&queue, 1000); +} + +TEST(ProducerConsumerQueue, TryGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + PushRanges(&queue, &pool); + } + TryGetRange(&queue, 1000); +} + +TEST(ProducerConsumerQueue, ParallelGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + pool.Schedule([&queue] { GetRange(&queue, 1000); }); + PushRanges(&queue, &pool); + } +} + +TEST(ProducerConsumerQueue, ParallelTryGetRange) { + IntQueue queue; + { + thread::ThreadPool pool(Env::Default(), "test", kNumThreads); + pool.Schedule([&queue] { TryGetRange(&queue, 1000); }); + PushRanges(&queue, &pool); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference_helpers.cc b/tensorflow/compiler/jit/shape_inference_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9cfa16526bc5d809942a35e86075b4ec6e88a59 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_helpers.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains helpers for use in shape inference. + +#include "tensorflow/compiler/jit/shape_inference_helpers.h" + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +Status BackEdgeHelper::Remove(Graph* graph) { + if (graph_ != nullptr) { + return errors::Internal("BackEdgeHelper duplicate call to Remove."); + } + graph_ = graph; + for (Node* n : graph_->nodes()) { + if (n->IsMerge()) { + for (const Edge* e : n->in_edges()) { + if (e->src()->IsNextIteration()) { + back_edges_.push_back( + BackEdge{e, e->src(), e->src_output(), e->dst(), e->dst_input()}); + } + } + } + } + for (const BackEdge& be : back_edges_) { + graph_->RemoveEdge(be.edge); + } + return Status::OK(); +} + +const std::vector& BackEdgeHelper::RemovedEdges() + const { + return back_edges_; +} + +Status BackEdgeHelper::Replace() { + if (graph_ == nullptr) { + return errors::Internal("BackEdgeHelper Replace called before Remove."); + } + if (replaced_) { + return errors::Internal("BackEdgeHelper Replace called more than once."); + } + replaced_ = true; + for (const BackEdge& be : back_edges_) { + graph_->AddEdge(be.src, be.src_output, be.dst, be.dst_input); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/shape_inference_helpers.h b/tensorflow/compiler/jit/shape_inference_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..2f053c9a45dd47ca1b056634d2248d6181e77d68 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference_helpers.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Helper class to temporarily remove, then replace, the back edges in a +// graph. Simple algorithms for shape inference don't work with cycles, and this +// class can be used to remove cycles before running inference and replace them +// after. Correct usage requires exactly one call to Remove(), followed by any +// number of calls to RemovedEdges() and at most one call to Replace(). The call +// to Replace() is optional if the graph will be discarded without being +// executed, e.g., if it is being used purely for a shape inference pass. +class BackEdgeHelper { + public: + struct BackEdge { + const Edge* edge; + Node* src; + int src_output; + Node* dst; + int dst_input; + }; + + BackEdgeHelper() = default; + // Disallows copy and assign. + BackEdgeHelper(const BackEdgeHelper& other) = delete; + BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete; + + // Temporarily removes all the back edges in graph. + Status Remove(Graph* graph); + + // Gets the list of removed edges. + const std::vector& RemovedEdges() const; + + // Replaces the back edges removed by a prior call to Remove. + Status Replace(); + + private: + Graph* graph_ = nullptr; // not owned + std::vector back_edges_; + // Set once Replace has been called. + bool replaced_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6d854a920eb0b4c01b09024ceaef5035e847d392..6430975335f5eef5b53c80213e6090ffd6166a91 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -92,38 +92,30 @@ uint64 XlaCompilationCache::Signature::Hash::operator()( } Status XlaCompilationCache::BuildSignature( - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, OpKernelContext* ctx, + const NameAttrList& function, const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, Signature* signature) { signature->name = Canonicalize(function.name(), AttrSlice(&function.attr())); - signature->arg_values.resize(num_constant_args); - - signature->arg_types.reserve(ctx->num_inputs() - num_constant_args); - - // Inputs are in the order: constants, non-constants, resource variables. - int input_num = 0; - // Use the values of compile time constants in the signature-> - while (input_num < num_constant_args) { - signature->arg_values[input_num] = ctx->input(input_num); - ++input_num; - } - // Add the types and shapes of the remaining arguments. - while (input_num < ctx->num_inputs() - variable_args.size()) { - signature->arg_types.emplace_back(ctx->input_dtype(input_num), - ctx->input(input_num).shape()); - ++input_num; - } - // For variable signatures, use the type and shape of the variable's - // current value. - for (const OptionalTensor& variable : variable_args) { - TF_RET_CHECK(input_num < ctx->num_inputs()); - if (variable.present) { - signature->arg_types.emplace_back(variable.value.dtype(), - variable.value.shape()); + signature->arg_values.reserve(constant_args.size()); + + signature->arg_types.reserve(ctx->num_inputs() - constant_args.size()); + + for (int i = 0; i < ctx->num_inputs(); ++i) { + if (constant_args.count(i) > 0) { + // Use the values of compile time constants in the signature. + signature->arg_values.push_back(constant_args.at(i)); + } else if (variable_args.count(i) > 0) { + const OptionalTensor& variable = variable_args.at(i); + if (variable.present) { + signature->arg_types.emplace_back(variable.value.dtype(), + variable.value.shape()); + } else { + signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + } } else { - signature->arg_types.emplace_back(DT_INVALID, TensorShape()); + signature->arg_types.emplace_back(ctx->input_dtype(i), + ctx->input(i).shape()); } - ++input_num; } return Status::OK(); } @@ -131,74 +123,58 @@ Status XlaCompilationCache::BuildSignature( namespace { // Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. The first `num_constant_args` arguments must be host-memory Tensors. -Status BuildArguments(int num_constant_args, - const std::vector& variable_args, +// op. +Status BuildArguments(const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, std::vector* args) { args->resize(ctx->num_inputs()); - int input_num = 0; - - // Handles compile-time constants. - TF_RET_CHECK(num_constant_args <= ctx->num_inputs()); - while (input_num < num_constant_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); - XlaCompiler::Argument& arg = (*args)[input_num]; - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input.dtype(); - arg.shape = input.shape(); - arg.constant_value = input; - ++input_num; - } - - // Handles the non-constant arguments. - int num_variable_args = variable_args.size(); - int num_nonconst_args = - ctx->num_inputs() - num_variable_args - num_constant_args; - TF_RET_CHECK(num_nonconst_args >= 0); - while (input_num < num_constant_args + num_nonconst_args) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() != DT_RESOURCE); + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { XlaCompiler::Argument& arg = (*args)[input_num]; - if (input.NumElements() > 0) { - arg.kind = XlaCompiler::Argument::kParameter; - } else { + if (constant_args.count(input_num) > 0) { + // Handles compile-time constants. + const Tensor& input = constant_args.at(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input.dtype(); + arg.shape = input.shape(); arg.constant_value = input; - } - arg.type = input.dtype(); - arg.shape = input.shape(); - ++input_num; - } - - // Handles resource variables. - TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs()); - for (int variable_id = 0; variable_id < num_variable_args; ++variable_id) { - const Tensor& input = ctx->input(input_num); - TF_RET_CHECK(input.dtype() == DT_RESOURCE); - - XlaCompiler::Argument& arg = (*args)[input_num]; - - arg.name = variable_args[variable_id].name; - arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = XlaResource::kVariable; - if (variable_args[variable_id].present) { - const Tensor& value = variable_args[variable_id].value; - arg.type = value.dtype(); - arg.shape = value.shape(); - arg.initialized = true; + } else if (variable_args.count(input_num) == 0) { + // Handles the non-constant arguments. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() != DT_RESOURCE); + if (input.NumElements() > 0) { + arg.kind = XlaCompiler::Argument::kParameter; + } else { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = input; + } + arg.type = input.dtype(); + arg.shape = input.shape(); } else { - // The values of uninitialized variables are not passed as inputs, since - // they are meaningless. However, it is legal to assign to a resource - // variable for the first time inside the XLA computation, so we do permit - // uninitialized variables. - arg.initialized = false; - arg.type = DT_INVALID; - arg.shape = TensorShape(); + // Handles resource variables. + const Tensor& input = ctx->input(input_num); + TF_RET_CHECK(input.dtype() == DT_RESOURCE); + const OptionalTensor& variable = variable_args.at(input_num); + arg.name = variable.name; + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = XlaResource::kVariable; + if (variable.present) { + const Tensor& value = variable.value; + arg.type = value.dtype(); + arg.shape = value.shape(); + arg.initialized = true; + } else { + // The values of uninitialized variables are not passed as inputs, since + // they are meaningless. However, it is legal to assign to a resource + // variable for the first time inside the XLA computation, so we do + // permit uninitialized variables. + arg.initialized = false; + arg.type = DT_INVALID; + arg.shape = TensorShape(); + } } - ++input_num; } return Status::OK(); @@ -233,16 +209,43 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - int num_constant_args, const std::vector& variable_args, - OpKernelContext* ctx, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options) { + return CompileImpl(options, function, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, false); +} + +Status XlaCompilationCache::CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { + const NodeDef& def = ctx->op_kernel().def(); + NameAttrList name; + name.set_name(def.op()); + *name.mutable_attr() = def.attr(); + return CompileImpl(options, name, constant_args, variable_args, ctx, + compilation_result, executable, compile_options, true); +} + +Status XlaCompilationCache::CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() - << " num_constant_args=" << num_constant_args + << " num_constant_args=" << constant_args.size() << " num_variable_args=" << variable_args.size(); for (int i = 0; i < ctx->num_inputs(); i++) { TensorShape shape = ctx->input(i).shape(); @@ -250,10 +253,12 @@ Status XlaCompilationCache::Compile( << " present=" << ctx->has_input(i) << " shape=" << shape.DebugString(); } - for (const OptionalTensor& variable : variable_args) { + for (auto& iterator : variable_args) { + const OptionalTensor& variable = iterator.second; VLOG(2) << "variable present=" << variable.present << " type=" << DataTypeString(variable.value.dtype()) - << " shape=" << variable.value.shape().DebugString(); + << " shape=" << variable.value.shape().DebugString() + << " TF arg= " << iterator.first; } VLOG(2) << "num_outputs = " << ctx->num_outputs(); for (int i = 0; i < ctx->num_outputs(); i++) { @@ -261,11 +266,12 @@ Status XlaCompilationCache::Compile( } } - TF_RET_CHECK(num_constant_args + variable_args.size() <= ctx->num_inputs()); + TF_RET_CHECK(constant_args.size() + variable_args.size() <= + ctx->num_inputs()); Signature signature; - TF_RETURN_IF_ERROR(BuildSignature(function, num_constant_args, variable_args, - ctx, &signature)); + TF_RETURN_IF_ERROR( + BuildSignature(function, constant_args, variable_args, ctx, &signature)); VLOG(2) << "Signature: " << SignatureDebugString(signature); // The outer lock protects the existence of the cache entry. It does not @@ -292,13 +298,20 @@ Status XlaCompilationCache::Compile( // a long time.) std::vector args; TF_RETURN_IF_ERROR( - BuildArguments(num_constant_args, variable_args, ctx, &args)); + BuildArguments(constant_args, variable_args, ctx, &args)); XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + + if (compile_single_op) { + entry->compilation_status = compiler.CompileSingleOp( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + signature.name, ctx, args, &entry->compilation_result); + } else { + entry->compilation_status = compiler.CompileFunction( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); + } } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 0858020716fcf4763e42dc0699ad22cfda756942..be1043d8c3fc0573922837e541615114a6d7a1a5 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -52,29 +52,52 @@ class XlaCompilationCache : public ResourceBase { // Compiles a function into a XlaCompiler::CompilationResult that can be used // to execute an XLA Computation. Compilation results are cached. // `function` is the name of a Tensorflow function to compile. - // `num_constant_args` is the number of compile-time constant arguments to - // `function`. `variable_args` is a snapshot of the current values of the + // `constant_args` is a map of tensorflow argument number to its constant + // value. + // `variable_args` is a snapshot of the current values of the // resource variable arguments to `function`; uninitialized variables are // represented by an absent OptionalTensor. // The result of compilation is written to `*compilation_result`, which must // be non-null. If `executable` is non-null, also builds an - // xla::LocalExecutable and sets `executable to point to it. The resulting + // xla::LocalExecutable and sets `executable` to point to it. The resulting // executable pointer may be null if the computation has no non-constant // outputs. Status Compile(const XlaCompiler::Options& options, - const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, const XlaCompiler::CompileOptions* compile_options); + // As above, but calls XlaCompiler::CompileSingleOp instead of + // XlaCompiler::CompileFunction. + Status CompileSingleOp( + const XlaCompiler::Options& options, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); + xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } string DebugString() override; private: + // Common implementation of Compile and CompileSingleOp. + Status CompileImpl(const XlaCompiler::Options& options, + const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, + OpKernelContext* ctx, + const XlaCompiler::CompilationResult** compilation_result, + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options, + bool compile_single_op); + // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. Status BuildExecutable(const XlaCompiler::Options& options, @@ -104,8 +127,9 @@ class XlaCompilationCache : public ResourceBase { static string SignatureDebugString(const Signature& sig); // Builds the signature for a compilation. - Status BuildSignature(const NameAttrList& function, int num_constant_args, - const std::vector& variable_args, + Status BuildSignature(const NameAttrList& function, + const std::map& constant_args, + const std::map& variable_args, OpKernelContext* ctx, Signature* signature); // The value associated with a cache entry. diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..682d6ea8ccc4a54912ccad4666cf0a7a03a7a698 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -0,0 +1,175 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines the XlaCompileOnDemandOp. + +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +namespace { +std::map GetVariables(OpKernelContext* ctx) { + std::map variables; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dtype() == DT_RESOURCE) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& optional = variables[i]; + optional.name = handle.name(); + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + optional.present = true; + optional.value = *variable->tensor(); + } + } + } + return variables; +} +} // namespace + +Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, + const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable) { + std::map variables = GetVariables(ctx); + int64 num_resource_args = variables.size(); + + xla::LocalClient* client = metadata.client(); + + // Builds an XLA allocator for the device. + XlaComputationLaunchContext launch_context( + num_resource_args, client, client->backend().memory_allocator(), true); + + launch_context.PopulateInputs(ctx, result, variables); + + perftools::gputools::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + TF_RET_CHECK(stream); + + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(client->backend().memory_allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + TF_RETURN_IF_ERROR(run_result.status()); + + launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + return Status::OK(); +} + +bool XlaCompileOnDemandOp::MustArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // TODO(jmolloy): This could be expensive, so memoize. + auto* constant_inputs = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( + op_kernel->def().op()); + CHECK(constant_inputs); + std::set constant_input_indices; + for (const auto& name : *constant_inputs) { + int start, stop; + TF_CHECK_OK(op_kernel->InputRange(name, &start, &stop)); + for (int i = start; i < stop; ++i) { + constant_input_indices.insert(i); + } + } + return constant_input_indices.count(argument_idx) > 0; +} + +bool XlaCompileOnDemandOp::ShouldArgumentBeConstant(const OpKernel* op_kernel, + int64 argument_idx) { + // Right now we only create kConstant arguments when absolutely required, but + // there may be benefit in eagerly constant-folding a larger subset of + // arguments in the future. + return MustArgumentBeConstant(op_kernel, argument_idx); +} + +Status XlaCompileOnDemandOp::Compile( + OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable) { + std::map constant_arguments; + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& device_tensor = ctx->input(i); + if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { + if (xla_tensor->has_host_tensor() && + ShouldArgumentBeConstant(&ctx->op_kernel(), i)) { + constant_arguments[i] = xla_tensor->host_tensor(); + } + } + if (constant_arguments.count(i) == 0 && + MustArgumentBeConstant(&ctx->op_kernel(), i)) { + // Slow path; the argument is not available as a host constant so we must + // fetch it synchronously. + Tensor host_tensor; + AllocatorAttributes attrs; + attrs.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_temp( + device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs)); + Notification n; + ctx->op_device_context()->CopyDeviceTensorToCPU( + &device_tensor, "ConstantArgument", + reinterpret_cast(ctx->device()), &host_tensor, + [&](Status status) { n.Notify(); }); + n.WaitForNotification(); + constant_arguments[i] = host_tensor; + } + } + + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + CHECK(rm); + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + *cache = new XlaCompilationCache(metadata.client(), + metadata.jit_device_type()); + return Status::OK(); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + XlaCompiler::Options options; + DeviceType device_type = metadata.jit_device_type(); + options.device_type = &device_type; + options.client = metadata.client(); + options.flib_def = + new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + + std::map variable_args = GetVariables(ctx); + return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, + result, executable, + /*compile_options=*/nullptr); +} + +void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { + const XlaCompiler::CompilationResult* result; + xla::LocalExecutable* executable; + const XlaDevice::Metadata* metadata; + OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); + OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable)); + OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable)); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h new file mode 100644 index 0000000000000000000000000000000000000000..23c6f3903f841a6c39104983c6f7f409757a7319 --- /dev/null +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The XlaCompileOnDemandOp is an OpKernel that, when its Compute method is +// called, will generate an xla::Computation and run it asynchronously. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ + +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// An OpKernel that compiles an op to an XLA computation and runs it. Unlike +// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// vanilla TensorFlow op as long as the bridge supports it. +// +// Importantly _XlaLaunch assumes all input and output tensors are on the host, +// whereas XlacompileOnDemandOp works with tensors in device memory. +class XlaCompileOnDemandOp : public OpKernel { + public: + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override; + + private: + XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); + bool ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + bool MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx); + Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable); + Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + const XlaCompiler::CompilationResult* result, + xla::LocalExecutable* executable); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index e238252751e677eb947f6df03e3b2f2e948ffe19..bc07dbd7bdf005fde781f7a1e6775080e363abfb 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -17,6 +17,8 @@ limitations under the License. // operators using XLA via the XLA "Host" (CPU) backend. #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -34,14 +36,24 @@ class XlaCpuDeviceFactory : public DeviceFactory { Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); + bool compile_on_demand = flags->tf_xla_compile_on_demand; + + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_CPU_XLA_JIT; + registration.requires_compilation = !compile_on_demand; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT); (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } @@ -50,8 +62,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index d4d8fe1c1d575b4e35d624621cc709e3a16569d5..12f471735f68394a3079541e9ac8532e329bd694 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -99,7 +100,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(backend, device_ordinal); + xla::MakeUnique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -108,21 +109,15 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, bool register_device_for_compilation, - std::unique_ptr* device) { + const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - if (register_device_for_compilation) { - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); - } + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { @@ -137,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( device->reset(new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie())); + platform.ValueOrDie(), transfer_as_literal)); return Status::OK(); } @@ -162,6 +157,7 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { + *metadata = nullptr; XlaDevice* xla_device = dynamic_cast(ctx->device()->UnderlyingDevice()); if (xla_device == nullptr) { @@ -177,13 +173,15 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { XlaDevice::XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform) + const DeviceType& jit_device_name, se::Platform* platform, + bool transfer_as_literal) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), - platform_(platform) {} + platform_(platform), + transfer_as_literal_(transfer_as_literal) {} XlaDevice::~XlaDevice() {} @@ -225,7 +223,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, VLOG(1) << "XlaDevice::FillContextMap"; device_context_map->resize(graph->num_node_ids()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - auto ctx = new XlaDeviceContext(stream); + // Call GetAllocator for the side-effect of ensuring the allocator and + // XlaTensorInfoManager is created. + (void)GetAllocator({}); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -273,7 +274,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream); + XlaTransferManager manager(stream, client(), transfer_as_literal_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; @@ -288,19 +289,23 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { + // Any op assigned to the device that isn't rewritten by the graph rewriter + // gets executed by a n XlaCompileOnDemandOp, which compiles it and executes + // it just-in-time. + kernel_factory::OpKernelRegistrar::Factory factory = + [](OpKernelConstruction* context) -> OpKernel* { + return new XlaCompileOnDemandOp(context); + }; XlaOpRegistry::RegisterCompilationKernels(); XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations; - auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { - return new XlaDeviceDummyOp(context); - }; for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( jit_device, /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( - new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp", - dummy_factory)); + new kernel_factory::OpKernelRegistrar(def, "XlaCompileOnDemandOp", + factory)); } return registrations; } diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d2ec38293c429f04f088bf3726ba97eb4e4b0dba..4fe7dd8c9fa9eb954804555e9615160dc4bc3e8a 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -26,6 +26,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -71,15 +73,20 @@ class XlaDevice : public LocalDevice { // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. + // 'transfer_as_literal' is true if device<->host transfers must be done using + // XLA's TransferLiteral{To,From}Device interface. If false, we can use + // ThenMemcpy instead. static Status Create(const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, - bool register_device_for_compilation, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - ::perftools::gputools::Platform* platform); + ::perftools::gputools::Platform* platform, + bool transfer_as_literal); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -104,7 +111,7 @@ class XlaDevice : public LocalDevice { // Which hardware device in the client's platform this XlaDevice controls. const int device_ordinal_; // The name of the device that is used to compile Ops for this XlaDevice. - const DeviceType& jit_device_name_; + DeviceType jit_device_name_; // Memory allocator associated with this device. Allocator* xla_allocator_; // Not owned. ::perftools::gputools::Platform* platform_; // Not owned. @@ -113,9 +120,12 @@ class XlaDevice : public LocalDevice { // copying back and forth between CPU and the device, and // computations enqueued by XLA. xla::Backend::StreamPtr stream_; + // Must we use XLA's transfer manager for correct host<->device transfers? if + // false, we can use ThenMemcpy() instead. + bool transfer_as_literal_; }; -// Builds dummy OpKernel registrations on 'device' for the JIT operators +// Builds OpKernel registrations on 'device' for the JIT operators // registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations // object that encapsulates the kernel registrations. struct XlaDeviceOpRegistrations { diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index c936222f32056e92efced82d5adb3a96c8041a17..43eb164012610723214cf39360698010c9dbdbd4 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/mem.h" @@ -26,33 +28,58 @@ namespace se = ::perftools::gputools; namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. -XlaDeviceAllocator::XlaDeviceAllocator(const xla::Backend* backend, - int device_ordinal) - : backend_(backend), device_ordinal_(device_ordinal) {} - +XlaDeviceAllocator::XlaDeviceAllocator() {} XlaDeviceAllocator::~XlaDeviceAllocator() = default; string XlaDeviceAllocator::Name() { return "xla"; } void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { - se::DeviceMemoryBase dmem = - backend_->memory_allocator() - ->Allocate(device_ordinal_, num_bytes, /*retry_on_failure=*/false) - .ValueOrDie(); - VLOG(2) << "Allocated XLA device tensor " << dmem.opaque() << "(" << num_bytes - << ")"; - return dmem.opaque(); + // We always return an empty XlaTensor object, encoded as an opaque tagged + // pointer. We can return an empty object and ignore num_bytes here because we + // have control over all of the uses of this device tensor, and can lazily + // allocate memory when used. This allows us to also know the shape of the + // allocated Tensor, which is useful if the device's tensor representation + // differs from the host. + return XlaTensor::ToOpaquePointer(new XlaTensor()); } void XlaDeviceAllocator::DeallocateRaw(void* ptr) { - se::DeviceMemoryBase dmem(ptr); - TF_CHECK_OK(backend_->memory_allocator()->Deallocate(device_ordinal_, &dmem)); - VLOG(2) << "Deallocated XLA device tensor " << ptr; + delete XlaTensor::FromOpaquePointer(ptr); } void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream) : stream_(stream) {} +XlaTransferManager::XlaTransferManager(se::Stream* stream, + xla::LocalClient* client, + bool transfer_as_literal) + : stream_(stream), + client_(client), + transfer_manager_(client->backend().transfer_manager()), + transfer_as_literal_(transfer_as_literal) {} + +Status XlaTransferManager::TransferLiteralToDevice( + const Tensor& host_tensor, Tensor* device_tensor) const { + xla::Literal literal; + TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal)); + VLOG(1) << "Transfer to device as literal: " << literal.ToString(); + + const xla::ShapedBuffer& shaped_buffer = + XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, + shaped_buffer); +} + +Status XlaTransferManager::TransferLiteralFromDevice( + Tensor* host_tensor, const Tensor& device_tensor) const { + const xla::ShapedBuffer& shaped_buffer = + XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + transfer_manager_->TransferLiteralFromDevice( + stream_->parent(), shaped_buffer)); + VLOG(1) << "Transfer from device as literal: " << literal->ToString(); + return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); +} void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, @@ -68,18 +95,35 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); - void* dst_ptr = DMAHelper::base(device_tensor); - se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); + XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); + CHECK(xla_tensor); + if (!xla_tensor->has_shaped_buffer()) { + Status s = xla_tensor->AllocateShapedBuffer( + device_tensor->dtype(), device_tensor->shape(), client_, + stream_->parent()->device_ordinal()); + if (!s.ok()) { + done(s); + return; + } + } + + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; - stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = TransferLiteralToDevice(*cpu_tensor, device_tensor); + } else { + stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } + xla_tensor->set_host_tensor(*cpu_tensor); done(status); return; @@ -103,18 +147,22 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, << device_tensor->NumElements(); const int64 total_bytes = cpu_tensor->TotalBytes(); - void* src_ptr = const_cast(DMAHelper::base(device_tensor)); - se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); + se::DeviceMemoryBase dev_src_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); void* dst_ptr = DMAHelper::base(cpu_tensor); Status status; - stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); - // TODO(hpucha): Make this asynchronous. - Status block_status = stream_->BlockHostUntilDone(); - if (!block_status.ok()) { - status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, - block_status.error_message().c_str()); + if (transfer_as_literal_) { + status = TransferLiteralFromDevice(cpu_tensor, *device_tensor); + } else { + stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); + // TODO(hpucha): Make this asynchronous. + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { + status = xla::InternalError( + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); + } } done(status); @@ -125,7 +173,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream) : manager_(stream) {} +XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, + bool transfer_as_literal) + : manager_(stream, client, transfer_as_literal) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index c4edcd474e48f791af9340c3cd6e4d031407bb68..ad914a1c23b5f2ea7063722f85e027a99fdb68f9 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -26,11 +27,12 @@ limitations under the License. namespace tensorflow { -// The allocator used for Tensors assigned to the XLA device. It uses -// XLA backend's allocator. +// The allocator used for Tensors assigned to the XLA device. The allocator +// ignores the alignment and size of the request and always returns a new, +// empty, XlaTensor. class XlaDeviceAllocator : public Allocator { public: - XlaDeviceAllocator(const xla::Backend* backend, int device_ordinal); + XlaDeviceAllocator(); ~XlaDeviceAllocator() override; string Name() override; @@ -38,18 +40,14 @@ class XlaDeviceAllocator : public Allocator { void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; void GetStats(AllocatorStats* stats) override; - - private: - // Which backend in the client this allocator belongs to. - const xla::Backend* backend_; - // Which hardware device in the client's backend this allocator belongs to. - const int device_ordinal_; }; // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(perftools::gputools::Stream* stream); + explicit XlaTransferManager(perftools::gputools::Stream* stream, + xla::LocalClient* client, + bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -59,9 +57,20 @@ class XlaTransferManager { perftools::gputools::Stream* stream() const { return stream_; } private: + Status TransferLiteralToDevice(const Tensor& host_tensor, + Tensor* device_tensor) const; + Status TransferLiteralFromDevice(Tensor* host_tensor, + const Tensor& device_tensor) const; + // Stream obtained from a Device, used to transfer tensors between // CPU and device. perftools::gputools::Stream* stream_; + // For the underlying memory allocator and XLA's TransferManager. + xla::LocalClient* client_; + // Transfer manager, for marshalling data to and from the device. + xla::TransferManager* transfer_manager_; + // True if we must use XLA's TransferManager for correct device transfers. + bool transfer_as_literal_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -69,7 +78,8 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(perftools::gputools::Stream* stream); + explicit XlaDeviceContext(perftools::gputools::Stream* stream, + xla::LocalClient* client, bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 2326070358d67c0cf30ef17fab5c93862cd8932c..ac60423d959ca44e7d92e2d965cf731287b1f83f 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -34,14 +34,21 @@ class XlaGpuDeviceFactory : public DeviceFactory { Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_GPU_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT); (void)registrations; std::unique_ptr device; - Status status = XlaDevice::Create( - "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, - /*register_device_for_compilation=*/true, &device); + Status status = + XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; @@ -55,8 +62,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 2614deefd8823dcb8f38e9e22ae4e78145d0d96a..9e098c46f422b436c722bb909dc58930ab7c0ef6 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -25,8 +25,8 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; +constexpr std::array kExecAllTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: @@ -41,10 +41,17 @@ Status XlaInterpreterDeviceFactory::CreateDevices( DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); (void)registrations; + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, - options, name_prefix, /*register_device_for_compilation=*/true, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, + DEVICE_INTERPRETER_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..50b0061d692f2a8c5ea475c0b00c4cb42a1a84e6 --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -0,0 +1,297 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_launch_util.h" + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace gpu = perftools::gputools; + +namespace tensorflow { + +std::map SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables) { + std::map snapshot; + int first_variable = ctx->num_inputs() - num_variables; + for (int i = 0; i < num_variables; ++i) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, first_variable + i); + OptionalTensor& tensor = snapshot[first_variable + i]; + if (LookupResource(ctx, handle, &variable).ok()) { + tf_shared_lock lock(*variable->mu()); + tensor.name = handle.name(); + tensor.present = true; + tensor.value = *variable->tensor(); + } + } + return snapshot; +} + +XlaAllocator::XlaAllocator(const gpu::Platform* platform, Allocator* wrapped) + : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} + +XlaAllocator::~XlaAllocator() {} + +xla::StatusOr XlaAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size); + if (data == nullptr) { + return errors::ResourceExhausted("Out of memory while trying to allocate ", + size, " bytes."); + } else { + return gpu::DeviceMemoryBase(data, size); + } +} + +Status XlaAllocator::Deallocate(int device_ordinal, + gpu::DeviceMemoryBase* mem) { + wrapped_->DeallocateRaw(mem->opaque()); + return Status::OK(); +} + +namespace { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. +std::unique_ptr ExtractSubShapedBuffer( + xla::ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator) { + xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer->on_host_shape(), index); + xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape( + shaped_buffer->on_device_shape(), index); + + xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, + shaped_buffer->platform(), + shaped_buffer->device_ordinal()); + + auto& shape_tree = shaped_buffer->buffers(); + auto& sub_shape_tree = sub_shaped_buffer.buffers(); + sub_shape_tree.CopySubtreeFrom(shape_tree, + /*source_base_index=*/{index}, + /*target_base_index=*/{}); + for (auto& index_to_buffer : shape_tree) { + if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) { + index_to_buffer.second = gpu::DeviceMemoryBase(nullptr, 0); + } + } + return xla::ScopedShapedBuffer::MakeScoped(&sub_shaped_buffer, allocator) + .ValueOrDie(); +} +} // namespace + +XlaComputationLaunchContext::XlaComputationLaunchContext( + int64 num_resource_args, xla::LocalClient* client, + xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) + : num_resource_args_(num_resource_args), + client_(client), + xla_allocator_(xla_allocator), + allocate_xla_tensors_(allocate_xla_tensors) {} + +void XlaComputationLaunchContext::PopulateInputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + const std::map& variables) { + // Build xla::ShapedBuffers that point directly to the Tensor buffers. + arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); + arg_buffers_.resize(kernel->xla_input_shapes.size()); + arg_ptrs_ = std::vector(arg_buffers_.size()); + + // Pass remaining parameters. + const Tensor* t; + for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { + int arg_num = kernel->input_mapping[i]; + const xla::Shape& shape = kernel->xla_input_shapes[i]; + if (variables.count(arg_num)) { + t = &(variables.at(arg_num).value); + CHECK(t); + } else { + t = &(ctx->input(arg_num)); + } + + const xla::Shape on_device_shape = + client_->backend().transfer_manager()->HostShapeToDeviceShape(shape); + if (xla::ShapeUtil::IsTuple(on_device_shape)) { + const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); + arg_ptrs_[i] = + const_cast(&xla_tensor->shaped_buffer()); + } else { + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + gpu::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); + arg_buffers_[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, + client_->platform(), client_->default_device_ordinal()); + arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); + arg_ptrs_[i] = arg_buffers_[i].get(); + } + } +} + +void XlaComputationLaunchContext::PopulateOutputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + std::unique_ptr output) { + gpu::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + // Computation output should always be a tuple. + if (VLOG_IS_ON(2)) { + VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); + VLOG(2) << "Result tuple shape (on device): " + << output->on_device_shape().DebugString(); + } + CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); + + // Copy XLA results to the OpOutputList. + int output_num = 0; + for (int i = 0; i < ctx->num_outputs(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); + if (kernel->outputs[i].is_constant) { + // Output is a constant. + const Tensor& const_tensor = kernel->outputs[i].constant_value; + Tensor* output_tensor; + const size_t total_bytes = const_tensor.TotalBytes(); + if (stream && total_bytes > 0) { + // Copy host -> device. (Empty tensors don't have backing buffers.) + // Manually allocate memory using an XlaTensorBuffer so we can allocate + // as much memory as the device requires (as given by + // GetByteSizeRequirement). This avoids XlaTransferManager having to + // reallocate the device buffer later. + VLOG(1) << "Constant output tensor on device"; + + OP_REQUIRES_OK( + ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { + OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( + const_tensor.dtype(), const_tensor.shape(), + client_, stream->parent()->device_ordinal())); + } + + Device* device = dynamic_cast(ctx->device()); + OP_REQUIRES(ctx, device != nullptr, + errors::Internal("DeviceBase was not a Device.")); + ctx->op_device_context()->CopyCPUTensorToDevice( + &const_tensor, device, output_tensor, + [&](Status status) { TF_CHECK_OK(status); }); + + if (device->device_type() == DEVICE_GPU) { + // The GPUDeviceContext enqueues the host->device transfer in a + // separate stream from the main compute stream. We must ensure the + // compute stream is synchronized with the host->device transfer + // stream now otherwise we will create a race condition. + auto* gpu_device_context = + static_cast(ctx->op_device_context()); + gpu_device_context->stream()->ThenWaitFor( + gpu_device_context->host_to_device_stream()); + } + } else { + // No copy required. + ctx->set_output(i, const_tensor); + output_tensor = ctx->mutable_output(i); + } + if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { + xla_tensor->set_host_tensor(const_tensor); + } + } else { + const TensorShape& shape = kernel->outputs[i].shape; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); + + gpu::DeviceMemoryBase buffer = output->buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer( + ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_)); + } else { + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num}); + ctx->set_output(i, output_tensor); + } + ++output_num; + } + + if (VLOG_IS_ON(3)) { + VLOG(3) << ctx->mutable_output(i)->DebugString(); + } + } + + // Apply variable updates, if any. + VLOG(2) << "Applying variable updates"; + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + Allocator* allocator = ctx->device()->GetAllocator({}); + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; + OP_REQUIRES(ctx, + write.input_index >= 0 && write.input_index < ctx->num_inputs(), + errors::Internal("Invalid input index for variable write.")); + + gpu::DeviceMemoryBase buffer = output->buffer({output_num}); + + Var* variable = nullptr; + // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, + // not a Tensor. + OP_REQUIRES_OK(ctx, LookupOrCreateResource( + ctx, HandleFromInput(ctx, write.input_index), + &variable, [this, ctx, &write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); + + core::ScopedUnref s(variable); + + mutex_lock ml(*variable->mu()); + OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, + errors::Internal("Mismatched type in variable write")); + + if (allocate_xla_tensors_) { + Tensor output_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + CHECK(xla_tensor); + xla_tensor->set_shaped_buffer( + ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_)); + *variable->tensor() = output_tensor; + } else { + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + write.type, write.shape, buffer, allocator); + output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num}); + *variable->tensor() = output_tensor; + } + ++output_num; + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h new file mode 100644 index 0000000000000000000000000000000000000000..14f70fe35891040ff3460567adb223be0f1c910f --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -0,0 +1,148 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for launching compiled XLA kernels for a KernelContext. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class XlaAllocator; + +// Takes a snapshot of the values of resource variable arguments, which are +// the last `num_variables` arguments. We snapshot tensors that back +// resource variables since concurrent updates may modify the shape, and it is +// important that the shapes used for compilation match the true shapes of the +// buffers. +// +// Returns a map of TensorFlow argument index to resource variable. +std::map SnapshotResourceVariables(OpKernelContext* ctx, + int num_variables); + +// Adapter class that wraps a Tensorflow allocator as an XLA allocator. +// Assumes that the Tensorflow allocator permits asynchronous deallocation: +// see comment on `AllowsAsynchronousDeallocation()`. +class XlaAllocator : public xla::DeviceMemoryAllocator { + public: + XlaAllocator(const perftools::gputools::Platform* platform, + Allocator* wrapped); + ~XlaAllocator() override; + xla::StatusOr Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) override; + Status Deallocate(int device_ordinal, + perftools::gputools::DeviceMemoryBase* mem) override; + + // The Tensorflow BFC allocator used on GPU allows host-side deallocation + // before GPU execution takes place. Tensorflow uses the ordering of the main + // compute stream to enforce a happens-before relationship between a memory + // allocation and code that reuses the same memory. If Tensorflow adds + // support for multiple GPU streams or allocators with different ordering + // requirements, this code may need to change. + // (This attribute has no effect on CPU.) + bool AllowsAsynchronousDeallocation() const override { return true; } + + private: + Allocator* wrapped_; +}; + +// Helper class to perform the marshalling of TensorFlow inputs and outputs to +// ShapedBuffers suitable for passing to an XLA computation. +class XlaComputationLaunchContext { + public: + // Create a new launch context. 'allocate_xla_tensors' is true if allocated + // output tensors and variables are always XlaTensors. If false they are + // assumed to be "normal" device pointers. + XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, + xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors); + + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). + // `variables` is a map from TensorFlow argument number to resource variable. + void PopulateInputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + const std::map& variables); + + // Given the XLA output in `output`, populate all outputs of `ctx`. + void PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + std::unique_ptr output); + + // Return the argument list. Only valid after PopulateInputs() has been + // called. + const std::vector& arguments() const { return arg_ptrs_; } + + private: + int64 num_resource_args_; + xla::LocalClient* client_; + xla::DeviceMemoryAllocator* xla_allocator_; + bool allocate_xla_tensors_; + std::vector> arg_buffers_; + std::vector arg_ptrs_; +}; + +// A simple TensorBuffer implementation that allows us to create Tensors that +// take ownership of pre-allocated memory. +class XlaTensorBuffer : public TensorBuffer { + public: + XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, + Allocator* allocator) + : expected_size_(expected_size), + actual_size_(actual_size), + allocator_(allocator) { + data_ = const_cast(ptr); + } + + ~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); } + + void* data() const override { return data_; } + size_t size() const override { return expected_size_; } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_allocated_bytes(actual_size_); + } + + static Tensor MakeTensor(DataType dtype, const TensorShape& shape, + perftools::gputools::DeviceMemoryBase buffer, + Allocator* allocator) { + size_t expected_size = shape.num_elements() * DataTypeSize(dtype); + auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size, + buffer.size(), allocator); + Tensor t(dtype, shape, tensor_buffer); + tensor_buffer->Unref(); + return t; + } + + private: + void* data_; + size_t expected_size_; + size_t actual_size_; + Allocator* allocator_; +}; + +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc new file mode 100644 index 0000000000000000000000000000000000000000..956328e6757f4c903e3995a54635682d19052794 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" + +namespace tensorflow { + +/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) { + if (tensor->NumElements() == 0) { + return nullptr; + } + XlaTensor* xla_tensor = + FromOpaquePointer(const_cast(tensor->tensor_data().data())); + return xla_tensor; +} + +/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { + return FromTensor(const_cast(tensor)); +} + +/*static*/ perftools::gputools::DeviceMemoryBase +XlaTensor::DeviceMemoryFromTensor(const Tensor& tensor) { + const XlaTensor* xla_tensor = FromTensor(&tensor); + if (xla_tensor) { + CHECK(xla_tensor->has_shaped_buffer()); + return xla_tensor->shaped_buffer().root_buffer(); + } else { + return perftools::gputools::DeviceMemoryBase( + const_cast(tensor.tensor_data().data()), + tensor.tensor_data().size()); + } +} + +Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + xla::LocalClient* client, + int device_ordinal) { + xla::Shape on_host_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &on_host_shape)); + xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape( + on_host_shape); + + xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(), + device_ordinal); + for (auto& index_to_buffer : buffer.buffers()) { + xla::Shape subshape = + xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); + uint64 size = + client->backend().transfer_manager()->GetByteSizeRequirement(subshape); + TF_ASSIGN_OR_RETURN(index_to_buffer.second, + client->backend().memory_allocator()->Allocate( + device_ordinal, size, /*retry_on_failure=*/false)); + } + + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + xla::ScopedShapedBuffer::MakeScoped( + &buffer, client->backend().memory_allocator())); + set_shaped_buffer(std::move(scoped_buffer)); + return Status::OK(); +} + +// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from +// device-side tensors, which are either CPU or GPU memory pointers. This works +// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits. +namespace { +constexpr uintptr_t kTag = 0x1ULL; +} + +/*static*/ XlaTensor* XlaTensor::FromOpaquePointer(void* ptr) { + uintptr_t value = reinterpret_cast(ptr); + if (value & kTag) { + return reinterpret_cast(value & ~kTag); + } else { + return nullptr; + } +} + +/*static*/ void* XlaTensor::ToOpaquePointer(XlaTensor* tensor) { + uintptr_t value = reinterpret_cast(tensor); + CHECK_EQ(value & kTag, 0); + value |= kTag; + return reinterpret_cast(value); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff2fb08f03548260215c6aeded2c124f8d28f43 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ + +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The implementation of a Tensor for an XlaDevice. All device tensors are +// actually one of these. +// +// To distinguish between "normal" device tensors and XlaTensors, the raw +// pointer data stored in the TensorBuffer is a tagged pointer. +class XlaTensor { + public: + // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast + // fails. + static XlaTensor* FromTensor(Tensor* tensor); + // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast + // fails. + static const XlaTensor* FromTensor(const Tensor* tensor); + + // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in + // which case the returned value is shaped_buffer()->root_buffer(), or a + // normal Tensor in which case the returned value is + // {tensor.tensor_data().data(), tensor.tensor_data().size}. + static perftools::gputools::DeviceMemoryBase DeviceMemoryFromTensor( + const Tensor& tensor); + + // Assign the internal ShapedBuffer to new memory for the given dtype and + // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it + // is replaced and the managed memory deallocated. + Status AllocateShapedBuffer(DataType dtype, const TensorShape& shape, + xla::LocalClient* client, int device_ordinal); + + // Some Tensors can have complex on-device shapes, including tuple shapes. To + // manage the memory for these tensors a ShapedBuffer may be required. + + // Return true if this TensorInfo contains a ShapedBuffer. + bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } + // Return the contained ShapedBuffer. + // REQUIRES: has_shaped_buffer() + const xla::ShapedBuffer& shaped_buffer() const { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + // Mutates the TensorInfo to set the ShapedBuffer. + void set_shaped_buffer( + std::unique_ptr shaped_buffer) { + shaped_buffer_ = std::move(shaped_buffer); + } + + // Some tensors on the device may have known values on the host. We use these + // in on-demand mode to avoid re-copying values from the device if we know the + // host value already. + + // Return true if this TensorInfo contains a host tensor. + bool has_host_tensor() const { return host_tensor_ != nullptr; } + // Return the contained host tensor. + // REQUIRES: has_host_tensor() + const Tensor& host_tensor() const { return *host_tensor_; } + // Sets the contained host tensor. + void set_host_tensor(const Tensor& tensor) { + host_tensor_.reset(new Tensor(tensor)); + } + + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. + static XlaTensor* FromOpaquePointer(void* ptr); + // Convert to a raw pointer from an XlaTensor, adding the pointer tag. + static void* ToOpaquePointer(XlaTensor* tensor); + + private: + // The optional contained ShapedBuffer. + std::unique_ptr shaped_buffer_; + // An optional host tensor value. + std::unique_ptr host_tensor_; +}; + +} // namespace tensorflow + +#endif diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index da4bc44c7a75c9f8faf16c537a17a1f2d16d5d61..238fd15166c0b08ee109d6a3888e16c39f87a603 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -49,17 +49,3 @@ cc_library( "//tensorflow/compiler/jit:xla_device", ], ) - -#----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 782bf82d4149968d5e5fbfb93bbd4ff1dcd75494..b9e42ca677cd82e2c18309d25ab33954206ebbe4 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -86,7 +86,10 @@ tf_xla_py_test( # ArgMax needs CustomCall on CPU, which is not available in normal # (not precompiled) TensorFlow. The flag below excludes the CPU # backend. - disabled_backends = "cpu", + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -98,7 +101,7 @@ tf_xla_py_test( tf_xla_py_test( name = "binary_ops_test", - size = "small", + size = "medium", srcs = ["binary_ops_test.py"], shard_count = 5, tags = [ @@ -121,6 +124,7 @@ tf_xla_py_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:framework_for_generated_wrappers", @@ -188,6 +192,31 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "oom_test", + size = "medium", + srcs = ["oom_test.py"], + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], + tags = [ + # Allocates very large amounts of memory and does not work under TSAN. + "notsan", + "optonly", # Times out frequently in fastbuild. + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:array_ops_gen", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradient_checker", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "conv2d_test", size = "medium", @@ -242,6 +271,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "dynamic_slice_ops_test", + size = "small", + srcs = ["dynamic_slice_ops_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -315,6 +356,8 @@ tf_xla_py_test( name = "function_test", size = "small", srcs = ["function_test.py"], + # Functions are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -466,6 +509,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reduce_window_test", + size = "small", + srcs = ["reduce_window_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "reverse_ops_test", size = "medium", @@ -550,6 +609,8 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + # Stack ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -576,6 +637,8 @@ tf_xla_py_test( name = "tensor_array_ops_test", size = "small", srcs = ["tensor_array_ops_test.py"], + # TensorArray ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -654,6 +717,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "while_test", + size = "small", + srcs = ["while_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "gather_test", size = "medium", @@ -826,17 +904,3 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 30a6d3a74d64f90ad33062df6d1e16e3a575bd63..1e4dd32916c3a40282735fb8f75670b0e9ef0dc9 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -71,7 +71,7 @@ class BinaryOpsTest(XLATestCase): expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([3, 3, -1.5, -8, 44], dtype=dtype), np.array([2, -2, 7, -4, 0], dtype=dtype), expected=np.array( @@ -108,57 +108,57 @@ class BinaryOpsTest(XLATestCase): [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( - gen_math_ops._reciprocal_grad, + gen_math_ops.reciprocal_grad, np.array([4, -3, -2, 1], dtype=dtype), np.array([5, -6, 7, -8], dtype=dtype), expected=np.array([-80, 54, -28, 8], dtype=dtype)) self._testBinary( - gen_math_ops._sigmoid_grad, + gen_math_ops.sigmoid_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-60, -36, -14, 0], dtype=dtype)) self._testBinary( - gen_math_ops._rsqrt_grad, + gen_math_ops.rsqrt_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-160, -81, -28, -4], dtype=dtype)) self._testBinary( - gen_math_ops._sqrt_grad, + gen_math_ops.sqrt_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([0.625, 1, 1.75, 4], dtype=dtype)) self._testBinary( - gen_nn_ops._softplus_grad, + gen_nn_ops.softplus_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array( [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype)) self._testBinary( - gen_nn_ops._softsign_grad, + gen_nn_ops.softsign_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array( [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype)) self._testBinary( - gen_math_ops._tanh_grad, + gen_math_ops.tanh_grad, np.array([4, 3, 2, 1], dtype=dtype), np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-75, -48, -21, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._elu_grad, + gen_nn_ops.elu_grad, np.array([1, 2, 3, 4, 5, 6], dtype=dtype), np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) self._testBinary( - gen_nn_ops._selu_grad, + gen_nn_ops.selu_grad, np.array([1, 2, 3, 4, 5, 6], dtype=dtype), np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype), expected=np.array( @@ -166,20 +166,20 @@ class BinaryOpsTest(XLATestCase): 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype)) self._testBinary( - gen_nn_ops._relu_grad, + gen_nn_ops.relu_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype)) self._testBinary( - gen_nn_ops._relu6_grad, + gen_nn_ops.relu6_grad, np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype), np.array( [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) self._testBinary( - gen_nn_ops._softmax_cross_entropy_with_logits, + gen_nn_ops.softmax_cross_entropy_with_logits, np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype), expected=[ @@ -190,24 +190,29 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) - self._testBinary( - gen_nn_ops._sparse_softmax_cross_entropy_with_logits, - np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], - [0.9, 1.0, 1.1, 1.2]], dtype=dtype), - np.array([2, 1, 7], dtype=np.int32), - expected=[ - np.array([1.342536, 1.442536, np.nan], dtype=dtype), - np.array([[0.213838, 0.236328, -0.738817, 0.288651], - [0.213838, -0.763672, 0.261183, 0.288651], - [np.nan, np.nan, np.nan, np.nan]], - dtype=dtype), - ], - equality_test=self.ListsAreClose) + # TODO(b/68813416): Fails with bfloat16. + if dtype != dtypes.bfloat16.as_numpy_dtype: + self._testBinary( + gen_nn_ops.sparse_softmax_cross_entropy_with_logits, + np.array( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], + [0.9, 1.0, 1.1, 1.2]], + dtype=dtype), + np.array([2, 1, 7], dtype=np.int32), + expected=[ + np.array([1.342536, 1.442536, np.nan], dtype=dtype), + np.array( + [[0.213838, 0.236328, -0.738817, 0.288651], [ + 0.213838, -0.763672, 0.261183, 0.288651 + ], [np.nan, np.nan, np.nan, np.nan]], + dtype=dtype), + ], + equality_test=self.ListsAreClose) def testIntOps(self): for dtype in self.int_types: self._testBinary( - gen_math_ops._truncate_div, + gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) @@ -232,11 +237,16 @@ class BinaryOpsTest(XLATestCase): expected=np.right_shift(lhs, rhs)) if dtype in [np.int8, np.int16, np.int32, np.int64]: - lhs = np.array([-1, -5, -3, -14], dtype=dtype) - rhs = np.array([5, 0, 1, 11], dtype=dtype) - self._testBinary( - bitwise_ops.right_shift, lhs, rhs, - expected=np.right_shift(lhs, rhs)) + lhs = np.array([-1, -5, -3, -14, -2], dtype=dtype) + rhs = np.array([5, 0, 1, 11, 36], dtype=dtype) + # HLO has saturating shift behavior. + bits = np.ceil( + np.log(np.iinfo(dtype).max - np.iinfo(dtype).min) / np.log(2)) + expected = [ + np.right_shift(l, r) if r < bits else np.sign(l) + for l, r in zip(lhs, rhs) + ] + self._testBinary(bitwise_ops.right_shift, lhs, rhs, expected=expected) def testNumericOps(self): for dtype in self.numeric_types: @@ -258,9 +268,9 @@ class BinaryOpsTest(XLATestCase): self._testBinary( math_ops.subtract, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([-9, -18], dtype=dtype)) + np.array([1, 2, 100], dtype=dtype), + np.array([10, 20, -1], dtype=dtype), + expected=np.array([-9, -18, 101], dtype=dtype)) self._testBinary( math_ops.subtract, dtype(5), @@ -350,6 +360,14 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + if np.int64 in self.numeric_types: + self._testBinary( + math_ops.add, + np.array([0xffffffff, 0xfffffffff, 1, 1], dtype=np.int64), + np.array([1, 1, 0xffffffff, 0xfffffffff], dtype=np.int64), + expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36], + dtype=np.int64)) + def testComplexOps(self): for dtype in self.complex_types: ctypes = {np.complex64: np.float32} @@ -369,7 +387,7 @@ class BinaryOpsTest(XLATestCase): expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), expected=np.array( @@ -378,7 +396,7 @@ class BinaryOpsTest(XLATestCase): # Test inf/nan scenarios. self._testBinary( - gen_math_ops._real_div, + gen_math_ops.real_div, np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( @@ -418,19 +436,19 @@ class BinaryOpsTest(XLATestCase): lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) self._testBinary( - gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) + gen_math_ops.reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) self._testBinary( - gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) + gen_math_ops.sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + gen_math_ops.rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( - gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) + gen_math_ops.sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) self._testBinary( - gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) + gen_math_ops.tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) def testComplexMath(self): for dtype in self.complex_types: @@ -538,7 +556,7 @@ class BinaryOpsTest(XLATestCase): if dtype not in self.complex_types: # floordiv unsupported for complex. self._testBinary( - gen_math_ops._floor_div, + gen_math_ops.floor_div, np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) @@ -554,12 +572,12 @@ class BinaryOpsTest(XLATestCase): def _testRemainder(self, dtype): """Test cases for remainder operators.""" self._testBinary( - gen_math_ops._floor_mod, + gen_math_ops.floor_mod, np.array([3, 3, -1, -8], dtype=dtype), np.array([2, -2, 7, -4], dtype=dtype), expected=np.array([1, -1, 6, 0], dtype=dtype)) self._testBinary( - gen_math_ops._truncate_mod, + gen_math_ops.truncate_mod, np.array([3, 3, -1, -8], dtype=dtype), np.array([2, -2, 7, -4], dtype=dtype), expected=np.array([1, 1, -1, 0], dtype=dtype)) @@ -668,6 +686,11 @@ class BinaryOpsTest(XLATestCase): np.array([[10], [7], [2]], dtype=np.float32), np.float32(7), expected=np.array([[False], [False], [True]], dtype=np.bool)) + self._testBinary( + less_op, + np.array([[10], [7], [2], [-1]], dtype=np.int64), + np.int64(7), + expected=np.array([[False], [False], [True], [True]], dtype=np.bool)) for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: self._testBinary( @@ -686,6 +709,80 @@ class BinaryOpsTest(XLATestCase): np.float32(7), expected=np.array([[False], [True], [True]], dtype=np.bool)) + def testS64Comparisons(self): + for op in [(lambda x, y: x < y), (lambda x, y: x <= y), + (lambda x, y: x >= y), (lambda x, y: x > y)]: + lhs = np.array( + [ + np.int64(0x000000007FFFFFFF), + np.int64(0x000000007FFFFFFF), + np.int64(0x0000000080000000), + np.int64(0x0000000080000000), + np.int64(0x0000000080000001), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFF), + np.int64(0x0000000100000000), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(0x0000000200000002), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x7ffffffefff00010), + np.int64(0x7ffffffefff00010), + np.int64(-1), + np.int64(-1) + ], + dtype=np.int64) + rhs = np.array( + [ + np.int64(0x000000007FFFFFFE), + np.int64(0x000000007FFFFFFF), + np.int64(0x000000007FFFFFFF), + np.int64(0x0000000080000000), + np.int64(0x0000000080000001), + np.int64(0x00000000FFFF0000), + np.int64(0x00000000FFFF0001), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(0x00000000FFFFFFFF), + np.int64(0x0000000100000001), + np.int64(0x0000000100000002), + np.int64(0x0000000100000003), + np.int64(0x0000000200000001), + np.int64(0x0000000200000002), + np.int64(0x0000000200000003), + np.int64(0x0000000300000001), + np.int64(0x0000000300000002), + np.int64(0x0000000300000003), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x00000000FFFFFFFE), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000002), + np.int64(-0x7FFFFFFF00000001), + np.int64(0x00000000FFFFFFFF), + np.int64(-0x7FFFFFFF00000001), + np.int64(-2), + np.int64(-1) + ], + dtype=np.int64) + expected = np.array([op(l, r) for l, r in zip(lhs, rhs)], dtype=np.bool) + self._testBinary(op, lhs, rhs, expected=expected) + def testBroadcasting(self): """Tests broadcasting behavior of an operator.""" @@ -1045,6 +1142,20 @@ class BinaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) + def splitvOp(x, y): # pylint: disable=invalid-name + return array_ops.split(value=y, num_or_size_splits=[2, 3], axis=x) + for axis in [1, -1]: + self._testBinary( + splitvOp, + np.int32(axis), + np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + dtype=dtype), + expected=[ + np.array([[0, 1], [5, 6]], dtype=dtype), + np.array([[2, 3, 4], [7, 8, 9]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + def testTile(self): for dtype in self.numeric_types: self._testBinary( diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 0528a5415d579a844e68403ace1bb8982a10a841..7b114d4f85d3a5cadc6af25b55c5a21f90d2a768 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -51,12 +51,12 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" ] elif backend == "gpu": backend_args += [ "--test_device=XLA_GPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16" ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: @@ -89,4 +89,3 @@ def generate_backend_suites(backends=[]): backends = all_backends() for backend in backends: native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend]) - diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index 5010fe5e21d0782e68d4e6d5bf6b4df1b44793a3..1a8989d7c2f617525c301f30fd899a01362310bf 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -34,6 +34,13 @@ from tensorflow.python.platform import test class CholeskyOpTest(XLATestCase): + # Cholesky defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/cholesky) + @property + def float_types(self): + return set(super(CholeskyOpTest, self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) self.assertAllClose(x, verification_np, atol=atol) diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index 81734082d9aab86f8bc763681265ef64ef32bd31..f10973e19f1945515b776cf86349445ed7334629 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -301,7 +301,7 @@ class ConcatOffsetTest(XLATestCase): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32) - off = gen_array_ops._concat_offset(cdim, [s0, s1, s2]) + off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) ans = sess.run(off) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6a46d2ec3e7aee3a4ecfbf1ab9f622d8eb659e3c --- /dev/null +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -0,0 +1,93 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA dynamic slicing ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class DynamicUpdateSliceOpsTest(XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + self.assertAllClose(result, expected, rtol=1e-3) + + def testUpdateSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array([], dtype=dtype), + np.array([], dtype=dtype), + np.array([0], dtype=np.int32) + ], + expected=np.array([], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), + np.array([-1, -2, -3], dtype=dtype), + np.array([6], dtype=np.int32) + ], + expected=np.array([1, 2, 3, 4, 5, 6, -1, -2, -3, 10], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype), + np.array([[42, 43], [44, 45]], dtype=dtype), + np.array([1, 2], dtype=np.int32) + ], + expected=np.array( + [[1, 2, 3, 4], [5, 6, 42, 43], [9, 10, 44, 45]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype), + np.array([[], []], dtype=dtype), + np.array([1, 2], dtype=np.int32) + ], + expected=np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + xla.dynamic_update_slice, [ + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=dtype), + np.ones([3, 4], dtype=dtype), + np.array([0, 0], dtype=np.int32) + ], + expected=np.ones([3, 4], dtype=dtype)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index f9db4cf2017c0b4b6dc0cfeeda6dca7bb9d14f19..8e6407dffdac3adbcda8cbca2109ef9196defa8c 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -134,9 +134,15 @@ class FtrlOptimizerTest(XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-2.60260963, -4.29698515]), var0.eval(), float_rtol=1e-5) + np.array([-2.60260963, -4.29698515]), + var0.eval(), + float_rtol=1e-5, + half_rtol=1e-2) self.assertAllCloseAccordingToType( - np.array([-0.28432083, -0.56694895]), var1.eval(), float_rtol=1e-5) + np.array([-0.28432083, -0.56694895]), + var1.eval(), + float_rtol=1e-5, + half_rtol=1e-2) def testFtrlwithoutRegularization2(self): for dtype in self.float_types: @@ -272,8 +278,8 @@ class FtrlOptimizerTest(XLATestCase): with self.test_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) - self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4) - self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4, half_rtol=1e-2) def testEquivGradientDescentwithoutRegularization(self): steps = 5 diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 11d8a99ffe1a136a54b16e20f1792062203f7969..fbc3c994d163a504351fcccd1ba71a0997e6516f 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -105,6 +105,28 @@ class FunctionTest(XLATestCase): result = sess.run(call_f) self.assertAllClose(result, expected, rtol=1e-3) + def testCompileTimeConstantsInDefun(self): + """Tests that XLA handles compile-time constants in defuns.""" + with self.test_session() as sess: + + @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) + def Foo(a, c, d): + # c and d must be known at compile time + x = array_ops.slice(a, c, d) + return x + + a = array_ops.placeholder(dtypes.float32) + c = array_ops.placeholder(dtypes.int32, shape=[4]) + d = array_ops.placeholder(dtypes.int32, shape=[4]) + with self.test_scope(): + call_f = Foo(a, c, d) + result = sess.run(call_f, feed_dict={ + a: np.ones([1, 4, 4, 1]), + c: [0, 0, 0, 0], + d: [1, 2, 2, 1]}) + + self.assertAllEqual(np.ones([1, 2, 2, 1]), result) + # TODO(b/36139787): Re-enable this test when noinline works again. def DISABLED_testFunctionsNoInline(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 538fa8e8e570b83ed681ecc0501285520cabdecb..12791ef8ac1da948608b1585f423ca217378f031 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -65,7 +65,8 @@ class RGBToHSVTest(XLATestCase): # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) self.assertAllClose(batch2, join2) - self.assertAllCloseAccordingToType(batch2, inp, bfloat16_atol=0.03) + self.assertAllCloseAccordingToType( + batch2, inp, bfloat16_atol=0.03, half_rtol=0.02) def testRGBToHSVRoundTrip(self): data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] @@ -426,7 +427,7 @@ class ResizeBilinearTest(XLATestCase): with self.test_session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) - resized = gen_image_ops._resize_bilinear_grad( + resized = gen_image_ops.resize_bilinear_grad( grads, np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), align_corners=True) diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 2d8236e2cbdfafb35626cd582ee39b1f917aec7f..1f7da659e5590b86c96964bbd14a4175341783c8 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -37,6 +39,18 @@ from tensorflow.python.platform import test jit_scope = jit.experimental_jit_scope +# Disable rewrites to make sure we don't end up having to update this test +# whenever we implement new ones. +def NoRewriteSessionConfig(): + rewriter_config = rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + function_optimization=rewriter_config_pb2.RewriterConfig.OFF) + graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) + return config_pb2.ConfigProto(graph_options=graph_options) + + def CompiledKernel(fn, *inputs, **kwargs): """Execute 'fn' as a compiled XLA kernel, with 'inputs'.""" name = kwargs.pop("name", None) @@ -80,7 +94,7 @@ class JitLaunchTest(test.TestCase): # actually ran. However, it is sometimes possible for _XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): - with session_lib.Session() as sess: + with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] feeds = {} for arg in args: @@ -257,7 +271,7 @@ class XlaCompilationTest(test.TestCase): def testReshape(self): """Tests an operator with compile-time constant and non-constant inputs.""" - with self.test_session() as sess: + with self.test_session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -281,7 +295,7 @@ class XlaCompilationTest(test.TestCase): def testIgnoredArguments(self): """Tests that JIT computations can ignore formal parameters.""" - with self.test_session() as sess: + with self.test_session(config=NoRewriteSessionConfig()) as sess: x = array_ops.placeholder(dtypes.int32) y = array_ops.placeholder(dtypes.int32) with jit_scope(): @@ -305,7 +319,7 @@ class XlaCompilationTest(test.TestCase): def testLoops(self): """Tests that compilation accepts computations containing loops.""" - with self.test_session() as session: + with self.test_session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): c = lambda i, _: math_ops.less(i, 5) @@ -323,7 +337,7 @@ class XlaCompilationTest(test.TestCase): def testCond(self): """Tests that compilation handles switch operators.""" - with self.test_session() as session: + with self.test_session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) y = array_ops.placeholder(dtypes.float32) c = array_ops.placeholder(dtypes.bool) @@ -364,7 +378,8 @@ class XlaCompilationTest(test.TestCase): inp = array_ops.placeholder(dtypes.float32) out = Entry(inp) - with self.test_session(graph=g, use_gpu=True) as sess: + with self.test_session( + config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess: run_metadata = config_pb2.RunMetadata() val = sess.run(out, feed_dict={inp: [2., 10.]}, @@ -376,7 +391,7 @@ class XlaCompilationTest(test.TestCase): def testLoopDeadlock(self): """Regression test for bug that caused deadlocks in graphs with loops.""" - with self.test_session() as session: + with self.test_session(config=NoRewriteSessionConfig()) as session: x = array_ops.placeholder(dtypes.float32) with jit_scope(): y = x + 1.0 @@ -403,10 +418,10 @@ class XlaCompilationTest(test.TestCase): y = Forward(x) dx, = gradients_impl.gradients(y, [x], 1.0) - cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( - optimizer_options=config_pb2.OptimizerOptions( - opt_level=config_pb2.OptimizerOptions.L1, - do_function_inlining=True))) + cfg = NoRewriteSessionConfig() + cfg.graph_options.optimizer_options.opt_level = ( + config_pb2.OptimizerOptions.L1) + cfg.graph_options.optimizer_options.do_function_inlining = True with session_lib.Session(graph=g, config=cfg) as sess: run_metadata = config_pb2.RunMetadata() dx_val = sess.run(dx, @@ -436,5 +451,55 @@ class XlaCompilationTest(test.TestCase): self.assertTrue(InLabels(labels, "_XlaLaunch")) +class ElementWiseFusionTest(test.TestCase): + + # Runs a simple test with the input jit_level and fusion_only flag. + def simpleTest(self, arg0, arg1, global_jit_level): + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = global_jit_level + + with session_lib.Session(config=config) as sess: + a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1") + a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2") + # Two element-wise ops. We need at least two ops since single + # element clusters are not passed to XLA in fusion_only mode. + a3 = a1 * a2 + a4 = a3 + a1 + # A matmul to break XLA clustering. + a5 = math_ops.matmul(a4, a1) + # Two more element-wise ops. + a6 = a5 - a4 + a7 = a6 + a2 + + run_metadata = config_pb2.RunMetadata() + output = sess.run( + a7, { + a1: arg0, + a2: arg1 + }, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = RunMetadataLabels(run_metadata) + count = sum("_XlaLaunch(" in x for x in labels) + + return output, count + + def testElementWiseClustering(self): + arg0 = np.random.rand(2, 2).astype(np.float32) + arg1 = np.random.rand(2, 2).astype(np.float32) + os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true" + tf_op, tf_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.OFF) + self.assertEqual(0, tf_count) + + tfef_op, tfef_count = self.simpleTest(arg0, arg1, + config_pb2.OptimizerOptions.ON_1) + self.assertEqual(2, tfef_count) + + self.assertAllClose(tf_op, tfef_op, rtol=1e-1) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 5d8d89224d4a778d84803811710bb095872e86b2..69bd8f7230d4394c45764d02a88fb0ec097c5756 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -115,11 +115,11 @@ class LRNTest(XLATestCase): out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) with ops.device(CPU_DEVICE): - expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, - depth_radius, bias, alpha, beta) + expected = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) with self.test_scope(): - actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image, - depth_radius, bias, alpha, beta) + actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image, + depth_radius, bias, alpha, beta) expected_val = expected.eval() actual_val = actual.eval() self.assertAllClose(actual_val, expected_val, rtol=1e-3) diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index cccb7f5789dce39ef8c3d4b3a7573aaa983b3fbd..5819b2bf2b55b9213a039c0ba82dd0bf1c738b00 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -37,6 +37,14 @@ def MakePlaceholder(x): class MatrixTriangularSolveOpTest(XLATestCase): + # MatrixTriangularSolve defined for float64, float32, complex64, complex128 + # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) + @property + def float_types(self): + return set(super(MatrixTriangularSolveOpTest, + self).float_types).intersection( + (np.float64, np.float32, np.complex64, np.complex128)) + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol): diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1434e965e3d7eaeca94ad0fa97498f884e30e115 --- /dev/null +++ b/tensorflow/compiler/tests/oom_test.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for out-of-memory conditions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class OutOfMemoryTest(xla_test.XLATestCase): + + def testOutputOutOfMemory(self): + """Allocates tensors until out of memory. + + Generates a large rank-1 tensor. The tensor is an output of an XLA + computation, not constant. + + Check that a ResourceExhaustedError is raised and can be caught. + + We spin in a loop generating larger and larger tensors until an OOM event + happens. We may be running sandboxed, so have a small host memory limit, so + any hardcoded value is unlikely to land in the sweet spot between device + memory size and host memory size with stability. + """ + + def test_loop(): + size = 2e8 + while True: + with self.test_session(): + # Force the compiled code to not be constant by feeding in an addend. + p = array_ops.placeholder(dtypes.float32, shape=[]) + with self.test_scope(): + # Create a large R1 tensor. + c = array_ops.zeros([size, 1]) + p + + c.eval(feed_dict={p: 1.0}) + size *= 2 + + self.assertRaises(errors.ResourceExhaustedError, test_loop) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index eb48fe555a0b182ea7983cbd8c3b217d56350408..4eed903963a34a253ea5c409782d9a89a97a4fdf 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import test # MaxPoolGrad. def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding): del outputs # Unused by average-pooling gradients. - return gen_nn_ops._avg_pool3d_grad( + return gen_nn_ops.avg_pool3d_grad( inputs.get_shape().as_list(), output_gradients, ksize=ksize, @@ -263,7 +263,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[1, 3, 3, 3, 1], ksize=[1, 1, 1], strides=[1, 1, 1], @@ -272,7 +272,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_1_6_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 3, 6, 3], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -281,7 +281,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_1_7_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 5, 7, 3], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -290,7 +290,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradValidPadding2_2_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 2, 2, 2, 3], ksize=[2, 2, 2], strides=[2, 2, 2], @@ -299,7 +299,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding1_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 2, 4, 1], ksize=[1, 1, 1], strides=[1, 1, 1], @@ -308,7 +308,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding2_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 3, 2, 4, 1], ksize=[2, 2, 2], strides=[1, 1, 1], @@ -317,7 +317,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding2_2_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[2, 5, 2, 4, 3], ksize=[2, 2, 2], strides=[2, 2, 2], @@ -326,7 +326,7 @@ class Pooling3DTest(XLATestCase): def testMaxPoolGradSamePadding3_1_3d(self): self._VerifyGradient( nn_ops.max_pool3d, - gen_nn_ops._max_pool3d_grad, + gen_nn_ops.max_pool3d_grad, input_sizes=[1, 3, 3, 7, 1], ksize=[3, 3, 3], strides=[1, 1, 1], diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 7c19a99c4eb4be3ca34b3ce949216e557b0a681d..fe270af3d636c0824621f36360ce9e7d14d8fc91 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -292,8 +292,15 @@ class PoolGradTest(XLATestCase): CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" - def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format): + def _VerifyOneTest(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=None): """Verifies the output values of the pooling gradient function. Args: @@ -304,9 +311,19 @@ class PoolGradTest(XLATestCase): strides: The stride dimensions padding: Padding type. data_format: The data format we use to run the pooling operation. + pool_grad_grad_func: Second-order gradient function, if available. """ total_size = np.prod(input_sizes) - x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) + # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally + # maximal at 16 bits. Switch to np.random.randn when resolved. + x = np.arange(1, total_size + 1, dtype=np.float32) + x *= (np.random.randint(2, size=total_size) * 2 - 1) # Flip signs randomly + # Verify some specifically interesting values... + x[np.random.choice(total_size)] = np.inf + x[np.random.choice(total_size)] = -np.inf + # TODO(b/74222344): Fix nan handling for max pool grad. + # x[np.random.choice(total_size)] = np.nan + x = x.reshape(input_sizes) with self.test_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). @@ -323,6 +340,8 @@ class PoolGradTest(XLATestCase): output_gradient_vals = np.arange( 1, output_vals.size + 1, dtype=np.float32) output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) + output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) + output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) # Use the Tensorflow CPU pooling gradient to compute the expected input # gradients. @@ -342,18 +361,36 @@ class PoolGradTest(XLATestCase): {inputs: x, output_gradients: output_gradient_vals}) + output_grad_gradients = array_ops.placeholder( + dtypes.float32, shape=expected_input_gradient_vals.shape) + if pool_grad_grad_func is not None: + expected_grad_gradients = pool_grad_grad_func( + inputs, + outputs, + output_grad_gradients, + ksize=ksize, + strides=strides, + padding=padding, + data_format="NHWC") + expected_grad_gradients_vals = sess.run(expected_grad_gradients, { + inputs: x, + output_grad_gradients: output_grad_grad_vals + }) + # Run the gradient op on the XLA device with self.test_scope(): outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) xla_inputs = inputs xla_outputs = outputs xla_output_gradients = output_gradients + xla_output_grad_gradients = output_grad_gradients xla_ksize = ksize xla_strides = strides if data_format == "NCHW": xla_inputs = NHWCToNCHW(inputs) xla_outputs = NHWCToNCHW(outputs) xla_output_gradients = NHWCToNCHW(output_gradients) + xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients) xla_ksize = NHWCToNCHW(ksize) xla_strides = NHWCToNCHW(strides) actual_input_gradients = pool_grad_func( @@ -366,22 +403,54 @@ class PoolGradTest(XLATestCase): data_format=data_format) if data_format == "NCHW": actual_input_gradients = NCHWToNHWC(actual_input_gradients) - actual = sess.run(actual_input_gradients, { + if pool_grad_grad_func is not None: + actual_grad_gradients = pool_grad_grad_func( + xla_inputs, + xla_outputs, + xla_output_grad_gradients, + ksize=xla_ksize, + strides=xla_strides, + padding=padding, + data_format=data_format) + if data_format == "NCHW": + actual_grad_gradients = NCHWToNHWC(actual_grad_gradients) + actual_input_gradients_vals = sess.run(actual_input_gradients, { inputs: x, outputs: output_vals, output_gradients: output_gradient_vals }) - # Compare the Tensorflow and XLA results. self.assertAllClose( - expected_input_gradient_vals.flatten(), - actual.flatten(), + expected_input_gradient_vals, + actual_input_gradients_vals, rtol=1e-4, atol=1e-6) - self.assertShapeEqual(actual, inputs) - - def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize, - strides, padding): + self.assertShapeEqual(actual_input_gradients_vals, inputs) + + if pool_grad_grad_func is not None: + actual_grad_gradients_vals = sess.run( + actual_grad_gradients, { + inputs: x, + outputs: output_vals, + output_grad_gradients: output_grad_grad_vals + }) + + # Compare the Tensorflow and XLA results. + self.assertAllClose( + expected_grad_gradients_vals, + actual_grad_gradients_vals, + rtol=1e-4, + atol=1e-6) + self.assertShapeEqual(actual_grad_gradients_vals, outputs) + + def _VerifyValues(self, + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + pool_grad_grad_func=None): """Verifies the output values of the pooling function. Args: @@ -391,12 +460,20 @@ class PoolGradTest(XLATestCase): ksize: The kernel size dimensions strides: The stride dimensions padding: Padding type. + pool_grad_grad_func: Second-order gradient function, if available. """ for data_format in GetTestConfigs(): - self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize, - strides, padding, data_format) - - def _TestPooling(self, forward_op, backward_op): + self._VerifyOneTest( + pool_func, + pool_grad_func, + input_sizes, + ksize, + strides, + padding, + data_format, + pool_grad_grad_func=pool_grad_grad_func) + + def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None): # VALID padding self._VerifyValues( forward_op, @@ -404,7 +481,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 3, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding self._VerifyValues( @@ -413,7 +491,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 3, 3], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, non square window self._VerifyValues( @@ -422,7 +501,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 2, 2, 1], ksize=[1, 1, 2, 1], strides=[1, 1, 1, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # VALID padding, uneven stride self._VerifyValues( @@ -431,14 +511,16 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 1, 2, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) self._VerifyValues( forward_op, backward_op, input_sizes=[1, 4, 4, 1], ksize=[1, 2, 2, 1], strides=[1, 2, 1, 1], - padding="VALID") + padding="VALID", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 4 input self._VerifyValues( @@ -447,7 +529,8 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 4, 4, 4], ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) # SAME padding, size 8 input self._VerifyValues( @@ -456,10 +539,14 @@ class PoolGradTest(XLATestCase): input_sizes=[1, 8, 8, 8], ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], - padding="SAME") + padding="SAME", + pool_grad_grad_func=pool_grad_grad_func) def testMaxPool(self): - self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad) + self._TestPooling( + nn_ops.max_pool, + gen_nn_ops.max_pool_grad, + pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad) def testAvgPool(self): # Wrapper around AvgPoolGrad that ignores extra arguments needed by @@ -467,7 +554,7 @@ class PoolGradTest(XLATestCase): def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding, data_format): del outputs # Unused by average-pooling gradients. - return gen_nn_ops._avg_pool_grad( + return gen_nn_ops.avg_pool_grad( inputs.get_shape().as_list(), output_gradients, ksize=ksize, @@ -483,7 +570,7 @@ class PoolGradTest(XLATestCase): def testMaxPoolKernelSmallerThanStrideValid(self): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 7, 7, 1], ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1], @@ -492,7 +579,7 @@ class PoolGradTest(XLATestCase): def testMaxPoolKernelSmallerThanStrideSame(self): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 3, 3, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], @@ -500,7 +587,7 @@ class PoolGradTest(XLATestCase): self._VerifyValues( nn_ops.max_pool, - gen_nn_ops._max_pool_grad, + gen_nn_ops.max_pool_grad, input_sizes=[1, 4, 4, 1], ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index e72dd4eea9f127e1df96ab166103c4c16372adb6..e53efc3091d8935e745122af29abd7b8063b1d01 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -83,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) { return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } -constexpr std::array kAllXlaTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}}; +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64, DT_INT64}}; // An OpTestBuilder is a graph builder class that takes as input an operator to // test, its inputs and attributes, and builds a graph that executes the diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 965fdf684b973498d0b3c3cde17711cca7279705..2c084b04fa2f67ad0d86508109522d7bead206eb 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -30,8 +31,13 @@ from tensorflow.python.platform import googletest class ReduceOpsTest(XLATestCase): - def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, - rtol=1e-4, atol=1e-4): + def _testReduction(self, + tf_reduce_fn, + np_reduce_fn, + dtype, + test_inputs, + rtol=1e-4, + atol=1e-4): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: @@ -41,16 +47,16 @@ class ReduceOpsTest(XLATestCase): index = array_ops.placeholder(dtypes.int32) out = tf_reduce_fn(a, index) result = sess.run(out, {a: test_input, index: [0]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=0), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=0), rtol=rtol, atol=atol) result = sess.run(out, {a: test_input, index: [1]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=1), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) result = sess.run(out, {a: test_input, index: [-1]}) - self.assertAllClose(result, np_reduce_fn(test_input, axis=1), - rtol=rtol, atol=atol) + self.assertAllClose( + result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Invalid reduction dim'): @@ -60,7 +66,7 @@ class ReduceOpsTest(XLATestCase): errors_impl.InvalidArgumentError, 'Invalid reduction dim'): sess.run(out, {a: test_input, index: [2]}) - FLOAT_DATA = [ + REAL_DATA = [ np.zeros(shape=(2, 0)), np.zeros(shape=(0, 30)), np.arange(1, 7).reshape(2, 3), @@ -74,7 +80,7 @@ class ReduceOpsTest(XLATestCase): np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_REAL_DATA = [x for x in REAL_DATA if np.size(x) > 0] NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), @@ -83,8 +89,7 @@ class ReduceOpsTest(XLATestCase): ] def testReduceSumF32(self): - self._testReduction(math_ops.reduce_sum, np.sum, np.float32, - self.FLOAT_DATA) + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA) def testReduceSumC64(self): self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, @@ -92,7 +97,7 @@ class ReduceOpsTest(XLATestCase): def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, - self.FLOAT_DATA) + self.REAL_DATA) def testReduceProdC64(self): self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, @@ -100,31 +105,44 @@ class ReduceOpsTest(XLATestCase): def testReduceMin(self): - def reference_min(inp, axis): + def reference_min(dtype, inp, axis): """Wrapper around np.amin that returns +infinity for an empty input.""" if inp.shape[axis] == 0: - return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + if np.issubdtype(dtype, np.floating): + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf')) + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + np.iinfo(dtype).max) return np.amin(inp, axis) - self._testReduction(math_ops.reduce_min, reference_min, np.float32, - self.FLOAT_DATA) + for dtype in set(self.all_types).intersection( + [np.float32, np.int32, np.int64]): + self._testReduction(math_ops.reduce_min, + functools.partial(reference_min, dtype), dtype, + self.REAL_DATA) def testReduceMax(self): - def reference_max(inp, axis): + def reference_max(dtype, inp, axis): """Wrapper around np.amax that returns -infinity for an empty input.""" if inp.shape[axis] == 0: - return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf')) + if np.issubdtype(dtype, np.floating): + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + float('-inf')) + return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], + np.iinfo(dtype).min) return np.amax(inp, axis) - self._testReduction(math_ops.reduce_max, reference_max, np.float32, - self.FLOAT_DATA) + for dtype in set(self.all_types).intersection( + [np.float32, np.int32, np.int64]): + self._testReduction(math_ops.reduce_max, + functools.partial(reference_max, dtype), dtype, + self.REAL_DATA) def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, - self.NONEMPTY_FLOAT_DATA) + self.NONEMPTY_REAL_DATA) def testReduceMeanC64(self): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e78a63465b80644d8810d9fa7433653bc4639fed --- /dev/null +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for xla.reduce_window.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class ReduceWindowTest(XLATestCase): + """Test cases for xla.reduce_window.""" + + def _reduce_window(self, operand, init, reducer, **kwargs): + with self.test_session(): + placeholder = array_ops.placeholder(operand.dtype) + with self.test_scope(): + output = xla.reduce_window(placeholder, init, reducer, **kwargs) + return output.eval(feed_dict={placeholder: operand}) + + def testReduceWindow(self): + + # TODO(b/77644762): float16 and float64 ReduceWindow are unimplemented. + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + self.assertAllClose( + np.array([3, 5, 7, 9, 11, 13], dtype=dtype), + self._reduce_window( + np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[2])) + + self.assertAllClose( + np.array([3, 7, 11], dtype=dtype), + self._reduce_window( + np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[2], + window_strides=[2])) + + self.assertAllClose( + np.array([1, 4, 7], dtype=dtype), + self._reduce_window( + np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[1], + window_strides=[3])) + + self.assertAllClose( + np.array([[24, 36, 24], [96, 0, 0]], dtype=dtype), + self._reduce_window( + np.array([[1, 2, 3, 4], [4, 3, 2, 1], [2, 4, 0, 1]], dtype=dtype), + 1.0, + mul_reducer, + window_dimensions=[2, 2], + window_strides=[1, 1])) + + self.assertAllClose( + np.array([[0, 0, 0], [5, 10, 5], [2, 4, 1], [0, 0, 0]], dtype=dtype), + self._reduce_window( + np.array([[1, 2, 3, 4], [4, 3, 2, 1], [2, 4, 0, 1]], dtype=dtype), + 0.0, + sum_reducer, + window_dimensions=[2, 2], + window_strides=[2, 2], + padding=[[2, 3], [1, 2]])) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 23bc39cf3f7087424719edfb8b6ee35d87295534..4a9c0e7471f9cdb2a47b54705495d2dda9748890 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -63,10 +63,10 @@ class SegmentReductionOpsTest(XLATestCase): def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): for dtype in self.numeric_types: self.assertAllClose( - np.array([0, 3, 2, 5], dtype=dtype), + np.array([6, 3, 0, 6], dtype=dtype), self.UnsortedSegmentSum( - np.array([0, 1, 2, 3, 4, 5], dtype=dtype), - np.array([3, -1, 2, 1, -1, 3], dtype=np.int32), 4)) + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): for dtype in self.numeric_types: diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index a7cbfb04003c397212a35e16c6b23d7c2a18f7df..305ca0c6b78d3ef985deb38816f9388e7983906b 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -137,6 +138,34 @@ class StridedSliceTest(XLATestCase): self.assertAllEqual([6, 4], result) + def test2DDegenerate(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + o = array_ops.strided_slice(i, [-1, 0], [0, 3]) + params = { + i: [[0, 1, 2], + [3, 4, 5]] + } + result = o.eval(feed_dict=params) + + self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + + def test2DDegenerateNegativeStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) + params = { + i: [[0, 1, 2], + [3, 4, 5]] + } + result = o.eval(feed_dict=params) + + self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + def test3D(self): for dtype in self.numeric_types: with self.test_session(): diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index c013f4b50a4cf95be8028248c52b10b1c3be2bd3..f37c34156f96761632247be4bc1b62fca54f666e 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import test @@ -75,11 +76,11 @@ class SpaceToBatchTest(XLATestCase): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) - x_tf = gen_array_ops._space_to_batch( + x_tf = gen_array_ops.space_to_batch( placeholder, paddings, block_size=block_size) self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) # inputs = batch_to_space(outputs) - x_tf = gen_array_ops._batch_to_space( + x_tf = gen_array_ops.batch_to_space( placeholder, paddings, block_size=block_size) self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) @@ -156,14 +157,32 @@ class SpaceToBatchNDTest(XLATestCase): paddings = np.array(paddings).reshape((len(block_shape), 2)) with self.test_session() as sess, self.test_scope(): for dtype in self.float_types: + # TODO(b/68813416): Skip bfloat16's as the input type for direct is + # float32 and results in a mismatch, while making testDirect provide the + # correctly typed input results in 'no fill-function for data-type' + # error. + if dtype == dtypes.bfloat16.as_numpy_dtype: + continue + if dtype == np.float16: + actual_inputs = np.array(inputs).astype(dtype) + actual_paddings = np.array(paddings).astype(dtype) + expected_outputs = np.array(outputs).astype(dtype) + else: + actual_inputs = inputs + actual_paddings = paddings + expected_outputs = outputs placeholder = array_ops.placeholder(dtype) # outputs = space_to_batch(inputs) - x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings) - self.assertAllEqual(sess.run(x_tf, {placeholder: inputs}), outputs) + x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, + actual_paddings) + self.assertAllEqual( + sess.run(x_tf, {placeholder: actual_inputs}), expected_outputs) # inputs = batch_to_space(outputs) placeholder = array_ops.placeholder(dtype) - x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, paddings) - self.assertAllEqual(sess.run(x_tf, {placeholder: outputs}), inputs) + x_tf = array_ops.batch_to_space_nd(placeholder, block_shape, + actual_paddings) + self.assertAllEqual( + sess.run(x_tf, {placeholder: expected_outputs}), actual_inputs) def _testDirect(self, input_shape, block_shape, paddings): inputs = np.arange(np.prod(input_shape), dtype=np.float32) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index 2b9c2279737ccee531d488d27ccdb0cafa1dc8fc..94342f9567ca71274609e63b0482d55637c98d51 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -34,33 +34,33 @@ class StackOpTest(XLATestCase): with self.test_session(), self.test_scope(): size = array_ops.placeholder(dtypes.int32) v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, v) + h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): with self.test_session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True) + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, x, swap_memory=True) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32) + c1 = gen_data_flow_ops.stack_pop_v2(h, dtypes.float32) self.assertAllClose(a, c1.eval({x: a})) def testMultiStack(self): with self.test_session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, v) + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_push_v2(h1, v) with ops.control_dependencies([c1]): - c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar") - c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0) + c1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="bar") + c2 = gen_data_flow_ops.stack_push_v2(h2, 5.0) with ops.control_dependencies([c2]): - c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + c2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) r = c1 + c2 self.assertAllClose(9.0, r.eval({v: 4.0})) @@ -69,15 +69,15 @@ class StackOpTest(XLATestCase): with self.test_session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) - h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") + h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + h2 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_push_v2(h1, v1) + c1 = gen_data_flow_ops.stack_push_v2(h1, v1) with ops.control_dependencies([c1]): - c2 = gen_data_flow_ops._stack_push_v2(h2, v2) + c2 = gen_data_flow_ops.stack_push_v2(h2, v2) with ops.control_dependencies([c2]): - pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32) - pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32) + pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32) + pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32) out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0}) self.assertAllClose(out1, 4.0) @@ -86,17 +86,17 @@ class StackOpTest(XLATestCase): def testCloseStack(self): with self.test_session() as sess, self.test_scope(): size = array_ops.placeholder(dtypes.int32) - h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo") - c1 = gen_data_flow_ops._stack_close_v2(h) + h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {size: 5}) def testPushCloseStack(self): with self.test_session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) - h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo") - c = gen_data_flow_ops._stack_push_v2(h, v) + h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") + c = gen_data_flow_ops.stack_push_v2(h, v) with ops.control_dependencies([c]): - c1 = gen_data_flow_ops._stack_close_v2(h) + c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {v: [[4.0, 5.0]]}) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index a62925a1818da00cb0a9e82e1281db20fb38b208..7624d6e4b2e2ece6a61155743fc8b866f6903f32 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -338,7 +338,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0 = ta.write(0, [[4.0, 5.0]]) # Test reading wrong datatype. - r0_bad = gen_data_flow_ops._tensor_array_read_v3( + r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtype2, flow_in=w0.flow) with self.assertRaisesOpError("TensorArray dtype is "): r0_bad.eval() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 3d3e112f4821ea8e57ea9589a5b4433647ad294b..ba79f393a8f9b24ac506d2130957c38ecd442509 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -154,6 +154,9 @@ class UnaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018. + if dtype == np.float16 and self.device == "XLA_CPU": + continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( math_ops.acos, @@ -600,6 +603,20 @@ class UnaryOpsTest(XLATestCase): src, expected=dst) + def testBitcast(self): + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int32), + np.array([1, 0x3f800000], np.int32), + expected=np.array([1, 0x3f800000], np.int32)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.float32), + np.array([1, 0x3f800000], np.int32), + expected=np.array([1e-45, 1.0], np.float32)) + self._assertOpOutputMatchesExpected( + lambda x: array_ops.bitcast(x, dtypes.int32), + np.array([1e-45, 1.0], np.float32), + expected=np.array([1, 0x3f800000], np.int32)) + def testInvertPermutation(self): self._assertOpOutputMatchesExpected( array_ops.invert_permutation, @@ -779,7 +796,10 @@ class UnaryOpsTest(XLATestCase): self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) self._assertSoftplusMatchesExpected( [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) - log_eps = np.log(np.finfo(dtype).eps) + if dtype == dtypes.bfloat16.as_numpy_dtype: + log_eps = np.log(np.finfo(np.float32).eps) + else: + log_eps = np.log(np.finfo(dtype).eps) one = dtype(1) ten = dtype(10) self._assertSoftplusMatchesExpected([ diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index b08d6ab21e0746558cb3d4818d4c822c45d2e9ee..8ecad00f6e23b3a7746bbb473102ac847bf4cbfd 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -230,7 +230,10 @@ class SliceAssignTest(XLATestCase): # shrink shape changes checker[1:2, 1] = [66] checker[1, 1:2] = [66] - checker[1, 1] = 66 + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker[1, 1] = 66 # newaxis shape changes checker[:, None, :] = [[[10, 20, 30]], [[40, 50, 50]]] # shrink and newaxis @@ -243,8 +246,11 @@ class SliceAssignTest(XLATestCase): # Assign vector to scalar (rank-0) using newaxis checker2 = StridedSliceAssignChecker(self, 222, dtype=dtype) - checker2[()] = 6 # no indices - checker2[...] = 6 # ellipsis + if dtype != dtypes.bfloat16.as_numpy_dtype: + # TODO(b/68813416): valnp call above results in an ndarray and not a + # number for bfloat16s. + checker2[()] = 6 # no indices + checker2[...] = 6 # ellipsis checker2[None] = [6] # new axis def testUninitialized(self): diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f79eb27435cc954cebde4357c1d946a320f4ed75 --- /dev/null +++ b/tensorflow/compiler/tests/while_test.py @@ -0,0 +1,130 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for while loops in XLA.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class WhileTest(XLATestCase): + + def testSingletonLoopHandrolled(self): + # Define a function for the loop body + @function.Defun(dtypes.int32) + def loop_body(step): + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + return step_out + + # Define a function for the loop condition + @function.Defun(dtypes.int32) + def loop_cond(step): + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) + + result = sess.run(loop_outputs, {init_index: 0}) + self.assertAllClose(result, [10], rtol=1e-3) + + def testCountingLoopHandrolled(self): + # Define a function for the loop body + @function.Defun(dtypes.int32, dtypes.float32) + def loop_body(step, rsum): + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + sum_out = rsum + constant_op.constant(1.5, dtype=dtypes.float32) + return step_out, sum_out + + # Define a function for the loop condition + @function.Defun(dtypes.int32, dtypes.float32) + def loop_cond(step, rsum): + del rsum + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + init_sum = array_ops.placeholder(dtypes.float32, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index, init_sum], loop_cond, + loop_body) + + result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0}) + self.assertAllClose(result, [10, 15.0], rtol=1e-3) + no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0}) + self.assertAllClose(no_iters_result, [10, 0.0], rtol=1e-3) + + def testCountingLoopHandrolledC64(self): + # Define a function for the loop body + @function.Defun(dtypes.int32, dtypes.complex64) + def loop_body(step, rsum): + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + sum_out = rsum + constant_op.constant(1.5 + 2j, dtype=dtypes.complex64) + return step_out, sum_out + + # Define a function for the loop condition + @function.Defun(dtypes.int32, dtypes.complex64) + def loop_cond(step, rsum): + del rsum + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + init_sum = array_ops.placeholder(dtypes.complex64, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index, init_sum], loop_cond, + loop_body) + + result = sess.run(loop_outputs, {init_index: 0, init_sum: 0.0}) + self.assertAllClose(result[1], np.complex64(15 + 20j), rtol=1e-3) + no_iters_result = sess.run(loop_outputs, {init_index: 10, init_sum: 0.0}) + self.assertAllClose(no_iters_result[1], np.complex64(0), rtol=1e-3) + + def testLoopWithConstantOutput(self): + # Define a function for the loop body + @function.Defun(dtypes.int32, dtypes.int32) + def loop_body(step, x): + del x + step_out = step + constant_op.constant(1, dtype=dtypes.int32) + return (step_out, 7) + + # Define a function for the loop condition + @function.Defun(dtypes.int32, dtypes.int32) + def loop_cond(step, x): + del x + return step < 10 + + with self.test_session() as sess: + init_index = array_ops.placeholder(dtypes.int32, []) + with self.test_scope(): + loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) + + result = sess.run(loop_outputs, {init_index: 0}) + self.assertAllClose(result, [10, 7], rtol=1e-3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 7e1f5c76ed65946363cc3c113ab1a9862f87b289..e924fe1e61454aefda622a5a46a0e483d26db5c1 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import contextlib +import os import random import re @@ -44,6 +45,8 @@ flags.DEFINE_string('test_device', None, flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.') flags.DEFINE_string('disabled_manifest', None, 'Path to a file with a list of tests that should not run.') +flags.DEFINE_string('tf_xla_flags', None, + 'Value to set the TF_XLA_FLAGS environment variable to') class XLATestCase(test.TestCase): @@ -71,14 +74,14 @@ class XLATestCase(test.TestCase): self._all_types = set( [dtype.as_numpy_dtype for dtype in self._all_tf_types]) - self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ dtype.as_numpy_dtype for dtype in self.complex_tf_types ]) - self._numeric_types = set( - self.int_types | self._float_types | self.complex_types) + self._numeric_types = set(self._int_types | self._float_types + | self.complex_types) # Parse the manifest file, if any, into a regex identifying tests to # disable @@ -97,6 +100,8 @@ class XLATestCase(test.TestCase): disabled_tests = [] disabled_method_types = [] for l in manifest_file.read().splitlines(): + if not l: + continue entry = comments_re.sub('', l).strip().split(' ') if len(entry) == 1: disabled_tests.append(entry[0]) @@ -113,6 +118,9 @@ class XLATestCase(test.TestCase): for name in types]) manifest_file.close() + if FLAGS.tf_xla_flags is not None: + os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags + @property def all_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) @@ -130,6 +138,11 @@ class XLATestCase(test.TestCase): name = '{}.{}'.format(type(self).__name__, self._testMethodName) return self._float_tf_types - self._method_types_filter.get(name, set()) + @property + def int_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._int_types - self._method_types_filter.get(name, set()) + @property def numeric_tf_types(self): name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index fb82c2601c432cee425a46a3b6dc2c55febeda87..ba5c3a14849cefcb680b03425232724ff32375a8 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -58,6 +58,15 @@ xla_proto_library( ], ) +xla_proto_library( + name = "host_compute_metadata_proto", + srcs = ["host_compute_metadata.proto"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "tf2xla", srcs = ["tf2xla.cc"], @@ -149,6 +158,7 @@ cc_library( ":common", ":dump_graph", ":functionalize_control_flow", + ":host_compute_metadata_proto", ":sharding_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", @@ -322,6 +332,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -404,7 +415,7 @@ cc_library( "//tensorflow/compiler/jit:graph_to_functiondef", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", - "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -426,7 +437,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", - "//tensorflow/compiler/tf2xla/cc:functional_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -452,17 +463,3 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index 311dddca94c458a60fd00afe5532840e0dbf0437..4f8bb8ad743afe69a6544c2ae0dc7309891b2df3 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -7,61 +7,23 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") tf_gen_op_wrapper_cc( - name = "functional_ops_gen", - include_internal_ops = 1, - out_ops_file = "ops/functional_ops", - deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"], + name = "xla_ops_gen", + out_ops_file = "ops/xla_ops", + deps = ["//tensorflow/compiler/tf2xla/ops:xla_ops"], ) cc_library( - name = "functional_ops", - srcs = ["ops/functional_ops.cc"], - hdrs = ["ops/functional_ops.h"], + name = "xla_ops", + srcs = ["ops/xla_ops.cc"], + hdrs = ["ops/xla_ops.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", "//tensorflow/cc:scope", - "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) - -tf_gen_op_wrapper_cc( - name = "sendrecv_ops_gen", - include_internal_ops = 1, - out_ops_file = "ops/sendrecv_ops", - deps = ["//tensorflow/compiler/tf2xla/ops:sendrecv_ops"], -) - -cc_library( - name = "sendrecv_ops", - srcs = ["ops/sendrecv_ops.cc"], - hdrs = ["ops/sendrecv_ops.h"], - deps = [ - "//tensorflow/cc:const_op", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 82923722c54d235716b9138d95a75a441df924ca..de1008803d69fefa415c7bdbe6c27a62e625b417 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -37,7 +37,7 @@ Status BackwardsConstAnalysis(const Graph& g, }; Status status; - std::unordered_set must_be_const; + std::unordered_set must_be_const; auto visit = [&status, &metadata_ops, &must_be_const, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -55,8 +55,10 @@ Status BackwardsConstAnalysis(const Graph& g, compile_time_const_args->at(index) = true; return; } - for (Node* pred : node->in_nodes()) { - must_be_const.insert(pred); + for (const Edge* pred : node->in_edges()) { + if (!pred->IsControlEdge()) { + must_be_const.insert(pred->src()); + } } return; } diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 9d125f8d499863cfaa0e26b5b633ca02914d1b7d..992b12c06db5efc0ae54284d0ea77017c1c79aca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -79,5 +79,24 @@ TEST(ConstAnalysisTest, TopologicalOrder) { } } +TEST(ConstAnalysisTest, DontFollowControlDependencies) { + Scope root = Scope::NewRootScope(); + + Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0); + Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1); + Output c1 = + ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1}); + Output add = ops::Add(root, arg1, c1); + Output reshape = ops::Reshape(root, arg1, add); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph)); + + std::vector const_args(2, false); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + + EXPECT_EQ(const_args, std::vector({false, true})); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f8169795ddfb7fd4e93d3f136c51623385868951..16b9142cbf7d2afe99c22acbc32fb17c09b00081 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -583,13 +583,15 @@ class FunctionalizeCond { // CondArgNode represents a input to the conditional and its corresponding // switch nodes. struct CondArgNode { - explicit CondArgNode(Node* input) : input(input) {} + explicit CondArgNode(Node* src, int src_output) + : src(src), src_output(src_output) {} string ToString() const { - return strings::StrCat("input=", input->name(), + return strings::StrCat("src=", src->name(), ":", src_output, " switches=", NodesToString(switches)); } - Node* input; + Node* src; + int src_output; std::vector switches; }; using CondArgNodes = std::vector; @@ -606,14 +608,15 @@ class FunctionalizeCond { // Group of switch nodes that will be part of the same XlaIf. struct SwitchCluster { - explicit SwitchCluster(Node* predicate) : predicate(predicate) {} + explicit SwitchCluster(const Edge* predicate_edge) + : predicate_edge(predicate_edge) {} string ToString() const { - return strings::StrCat(name, " predicate=", predicate->name(), + return strings::StrCat(name, " predicate=", predicate_edge->src()->name(), " switches=", NodesToString(switches)); } string name; - Node* predicate; + const Edge* predicate_edge; std::vector switches; }; @@ -653,8 +656,8 @@ class FunctionalizeCond { Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate, - Node* if_node); + Status AddInputEdges(const CondArgNodes& cond_arg_nodes, + const Edge* predicate_edge, Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); @@ -756,8 +759,8 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, if (IsMerge(dst)) { dst_state->branch = Branch::kBoth; } else { - return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", - dst_state->ToString(), " for ", + return errors::Internal("Illegal merge:\n", src_state.ToString(), + " with ", dst_state->ToString(), " for\n", dst->DebugString()); } } @@ -861,8 +864,8 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { if (IsSwitch(n)) { Node* input; TF_CHECK_OK(n->input_node(0, &input)); - entry_cluster[n->id()] = &clusters[input->id()]; - UnionFind* cluster = find_output_cluster(input); + entry_cluster[n->id()] = find_output_cluster(input); + UnionFind* cluster = entry_cluster[n->id()]; int cluster_depth = switch_depth[cluster->Get().representative]; // Merge the inputs of the switch node with one another. This results in // predicates and control input residing in the same cluster. @@ -898,6 +901,14 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { int src_depth = switch_depth[src_id]; if (!e->IsControlEdge() || new_switch_depth == src_depth) { if (src_depth != new_switch_depth) { + // TODO(b/77601805) remove this when outside_compilation supports + // control flow. + if (str_util::StrContains(src->name(), "outside_compilation") || + str_util::StrContains(n->name(), "outside_compilation")) { + return errors::InvalidArgument( + "outside_compilation is not yet supported within TensorFlow " + "control flow constructs b/77601805"); + } return errors::InvalidArgument( "Unable to functionalize control flow in graph: Operand ('", src->name(), "') and operator ('", n->name(), @@ -956,16 +967,21 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { // node whose cluster is later in the topological order of clustered // switches). for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { - Node* pred; - TF_CHECK_OK((*it)->input_node(1, &pred)); - auto repr = std::make_pair(pred, clusters[(*it)->id()].Get()); + const Edge* pred_edge; + TF_CHECK_OK((*it)->input_edge(1, &pred_edge)); + // The predicate can be preceded by a identity node. Look through identity + // nodes to predicate. + while (pred_edge->src()->IsIdentity()) { + TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge)); + } + auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get()); if (predicate_index.find(repr) == predicate_index.end()) { predicate_index[repr] = switch_clusters.size(); - switch_clusters.emplace_back(pred); + switch_clusters.emplace_back(pred_edge); // Generate a name by concatenating with the cluster representative as // there could be multiple switch clusters with the same predicate. - switch_clusters[predicate_index[repr]].name = - strings::StrCat(pred->name(), "_", repr.second.representative, "_If"); + switch_clusters[predicate_index[repr]].name = strings::StrCat( + pred_edge->src()->name(), "_", repr.second.representative, "_If"); } switch_clusters[predicate_index[repr]].switches.push_back(*it); } @@ -1044,9 +1060,12 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( ForwardFlowNode& ffn = branch_map[out]; if (IsSwitch(n)) { int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); - TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ", + e->DebugString()); } else { - TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn), + " when joining ", e->DebugString()); } if (IsMerge(out)) { if (out->in_edges().size() == ffn.count) { @@ -1083,8 +1102,7 @@ Status FunctionalizeCond::FunctionalizeInternal() { for (auto it = predicate_switch_order.rbegin(); it != predicate_switch_order.rend(); ++it) { auto& ps = *it; - VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " (" - << ps.predicate->name() << ")"; + VLOG(3) << "Flow down from: " << ps.ToString(); std::unordered_map branch_map; std::unordered_set frontier; @@ -1097,21 +1115,29 @@ Status FunctionalizeCond::FunctionalizeInternal() { library_); TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second)); + } + }; + // Sort the merge and switch nodes using NodeCmp. The switch-nodes are // further grouped (post sorting) by input to the switch node as in the // functionalized form each input will be passed in only once. This grouping // should retain the sorted order. CondArgNodes cond_arg_nodes; - std::unordered_map input_index; std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); + std::unordered_map, int, Hash> input_index; for (Node* switch_node : ps.switches) { - Node* in; - TF_RETURN_IF_ERROR(switch_node->input_node(0, &in)); - if (input_index.find(in) == input_index.end()) { - input_index[in] = cond_arg_nodes.size(); - cond_arg_nodes.emplace_back(in); + const Edge* e; + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); + std::pair key = std::make_pair(e->src(), e->src_output()); + if (input_index.find(key) == input_index.end()) { + input_index[key] = cond_arg_nodes.size(); + cond_arg_nodes.emplace_back(key.first, key.second); } - cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node); + cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node); } std::vector merge_nodes(frontier.begin(), frontier.end()); std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); @@ -1200,11 +1226,12 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(switch_cluster.predicate->assigned_device_name()); + builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name()); // Conditional should be the first input ... - builder.Input( - NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0, - switch_cluster.predicate->output_type(0))); + builder.Input(NodeDefBuilder::NodeOut( + switch_cluster.predicate_edge->src()->name(), + switch_cluster.predicate_edge->src_output(), + switch_cluster.predicate_edge->src()->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1264,24 +1291,17 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, } Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, - Node* predicate, Node* if_node) { + const Edge* predicate_edge, + Node* if_node) { VLOG(3) << "AddInputEdges for " << if_node->name(); int index = 0; - graph_->AddEdge(predicate, 0, if_node, index++); - for (auto& kv : cond_arg_nodes) { - bool inserted = false; - for (const Node* arg : kv.switches) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph_->AddControlEdge(in_edge->src(), if_node); - } else { - if (!inserted) { - graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, - index++); - inserted = true; - } - } + graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node, + index++); + for (auto& arg : cond_arg_nodes) { + if (arg.src_output == Graph::kControlSlot) { + graph_->AddControlEdge(arg.src, if_node); + } else { + graph_->AddEdge(arg.src, arg.src_output, if_node, index++); } } return Status::OK(); @@ -1302,10 +1322,10 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return errors::Unimplemented("Output of index (", edge->src_output(), ") of merge node ", node->name()); } - graph_->RemoveEdge(edge); int src_output = dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph_->RemoveEdge(edge); graph_->AddEdge(if_node, src_output, dst, dst_input); } } @@ -1323,7 +1343,7 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( Node * if_node, BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); TF_RETURN_IF_ERROR( - AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node)); + AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return if_node; @@ -1345,6 +1365,7 @@ Status FunctionalizeControlFlow(Graph* graph, VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); + // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index bc7276c3afd5060d6faeceb4d479416299ecc5da..e494f42e8ed254ac0c7c7a23a13728d3f015e9d3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md index 91351421bcacd26c41b5c9f98ea833730e4aef30..20179b67991d3d23d678cf1df2642e029ea037fd 100644 --- a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -186,6 +198,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -198,6 +211,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md index b9bdb829d773825005a8921f48d28b6892d8f0cd..55f0538dba7c1941dfea88e0631cd299e51f76d0 100644 --- a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -3,6 +3,7 @@ Operator | Type Constraint ------------------------------------- | --------------- `Abs` | `T={double,float,int32,int64}` +`Acos` | `T={complex64,double,float,int32,int64}` `Acosh` | `T={complex64,double,float}` `Add` | `T={complex64,double,float,int32,int64}` `AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` @@ -15,10 +16,12 @@ Operator | Type Constraint `ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asin` | `T={complex64,double,float,int32,int64}` `Asinh` | `T={complex64,double,float}` `AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` `AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan` | `T={complex64,double,float,int32,int64}` `Atan2` | `T={double,float}` `Atanh` | `T={complex64,double,float}` `AvgPool` | `T={double,float}` @@ -75,6 +78,10 @@ Operator | Type Constraint `FFT` | `FFT2D` | `FFT3D` | +`FakeQuantWithMinMaxArgs` | +`FakeQuantWithMinMaxArgsGradient` | +`FakeQuantWithMinMaxVars` | +`FakeQuantWithMinMaxVarsGradient` | `Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Floor` | `T={double,float}` `FloorDiv` | `T={complex64,double,float,int32,int64}` @@ -84,6 +91,7 @@ Operator | Type Constraint `FusedBatchNormGradV2` | `U={float}`
`T={float}` `FusedBatchNormV2` | `U={float}`
`T={float}` `Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherNd` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` `Greater` | `T={double,float,int32,int64,uint32,uint64}` `GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` @@ -117,14 +125,18 @@ Operator | Type Constraint `LogicalNot` | `LogicalOr` | `MatMul` | `T={complex64,double,float}` +`MatrixBandPart` | `Tindex={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixSetDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `MatrixTriangularSolve` | `T={complex64,double,float}` `Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` `MaxPool` | `T={double,float,int32,int64}` `MaxPool3D` | `T={float}` `MaxPool3DGrad` | `TInput={float}`
`T={float}` `MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradGrad` | `T={float}` +`MaxPoolGradGradV2` | `T={float}` `MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` `MaxPoolV2` | `T={double,float,int32,int64}` `Maximum` | `T={double,float,int32,int64}` @@ -183,6 +195,7 @@ Operator | Type Constraint `Round` | `T={complex64,double,float,int32,int64}` `Rsqrt` | `T={complex64,double,float}` `RsqrtGrad` | `T={complex64,double,float}` +`ScatterNd` | `Tindices={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Selu` | `T={double,float}` `SeluGrad` | `T={double,float}` @@ -195,6 +208,7 @@ Operator | Type Constraint `Sinh` | `T={complex64,double,float}` `Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Snapshot` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` `Softmax` | `T={double,float}` `SoftmaxCrossEntropyWithLogits` | `T={double,float}` `Softplus` | `T={double,float,int32,int64,uint32,uint64}` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 058a1f2621c64a735bd9d9c9d0ae007f93aa4dea..b20c1ffc7d8956f3f5530ee63e9b711a26439be5 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -130,7 +130,7 @@ Status GraphCompiler::Compile() { // Set up inputs from outputs of previous nodes. for (auto* e : n->in_edges()) { if (e->IsControlEdge()) continue; - Node* src = e->src(); + const Node* src = e->src(); TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto b/tensorflow/compiler/tf2xla/host_compute_metadata.proto new file mode 100644 index 0000000000000000000000000000000000000000..43ab371a217e6c4521a160715104c96e3c8782c6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/host_compute_metadata.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; + +package tensorflow.tf2xla; +option cc_enable_arenas = true; +option java_outer_classname = "Tf2XlaProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.tf2xla"; + +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// TensorMetadata indicates the type and shape of a Tensor that is +// part of a host compute transfer. +message TensorMetadata { + DataType type = 1; + TensorShapeProto shape = 2; +} + +// HostTransferMetadata describes a transfer either from host to device +// or device to host. It has a key that is unique to the computation, +// and metadata about the list of tensors being transferred. +message HostTransferMetadata { + // The key used to identify this transfer. + string key = 1; + + // For each Tensor being transferred, its type and shape. + repeated TensorMetadata metadata = 2; +} + +// HostComputeMetadata describes all the sends and recvs +// from all host compute transfer ops in a computation. +message HostComputeMetadata { + // Metadata about each device_to_host transfer + repeated HostTransferMetadata device_to_host = 1; + + // Metadata about each host_to_device transfer + repeated HostTransferMetadata host_to_device = 2; +} diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index d2fa933cf9c085f92b2f442827a94d72938e4bb2..579b66969990017688477443115cc4f61c18fe4a 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -29,6 +29,7 @@ tf_kernel_library( "cwise_ops.h", "depthtospace_op.cc", "diag_op.cc", + "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", "elu_op.cc", "extract_image_patches_op.cc", @@ -56,6 +57,7 @@ tf_kernel_library( "pooling_ops.cc", "quantize_and_dequantize_op.cc", "random_ops.cc", + "reduce_window_op.cc", "reduction_ops.cc", "reduction_ops.h", "reduction_ops_common.cc", @@ -93,6 +95,7 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -102,7 +105,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:while_loop", - "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -145,7 +148,7 @@ tf_kernel_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:framework", @@ -154,6 +157,39 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "if_op", + srcs = ["if_op.cc"], + hdrs = ["if_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# Kernels that have a dummy (no-op) implementation. +tf_kernel_library( + name = "xla_dummy_ops", + srcs = [ + "assert_op.cc", + "check_numerics_op.cc", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:logging_ops_op_lib", + ], + alwayslink = 1, +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. tf_kernel_library( @@ -200,17 +236,3 @@ cc_library( ], alwayslink = 1, ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/kernels/assert_op.cc b/tensorflow/compiler/tf2xla/kernels/assert_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..af4ab5e8ef6e268226edc90515706405ac36858c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/assert_op.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +namespace { + +// This TensorFlow op supports the Assert primitve. +class AssertOp : public XlaOpKernel { + public: + explicit AssertOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + ~AssertOp() override {} + + void Compile(XlaOpKernelContext* ctx) override { + static mutex mu(tensorflow::LINKER_INITIALIZED); + static int log_counter = 0; + + mutex_lock l(mu); + if (log_counter < 20) { + ++log_counter; + LOG(WARNING) << "Ignoring Assert operator " << name(); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(AssertOp); +}; + +REGISTER_XLA_OP(Name("Assert"), AssertOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index a249b1869f547f8e5aa725f9f5cf391b10429928..931175be1111ed5f70afbdf351ee53c59c1367de 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -118,30 +118,24 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); - - auto grad_backprop = ctx->Input(0); - auto activations = ctx->Input(1); - auto scale = ctx->Input(2); - auto mean = ctx->Input(3); - auto var = ctx->Input(4); - - TensorShape input_shape = ctx->InputShape(0); - int feature_index = - GetTensorFeatureDimIndex(input_shape.dims(), data_format_); - + xla::ComputationBuilder* const b = ctx->builder(); DataType input_dtype = ctx->input_type(0); DataType scale_dtype = ctx->input_type(2); - xla::PrimitiveType input_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); - xla::PrimitiveType scale_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); // TODO(b/69928690): support mixed precision in the XLA batch normalization // operators. For now, cast everything to the statistics type (which // may be more precise than the input type). - grad_backprop = b->ConvertElementType(grad_backprop, scale_type); - activations = b->ConvertElementType(activations, scale_type); + auto grad_backprop = + XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype); + auto activations = + XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype); + auto scale = ctx->Input(2); + auto mean = ctx->Input(3); + auto var = ctx->Input(4); + + const int input_dims = ctx->InputShape(0).dims(); + const int feature_index = + GetTensorFeatureDimIndex(input_dims, data_format_); xla::ComputationDataHandle x_backprop; xla::ComputationDataHandle scale_backprop; @@ -156,7 +150,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { offset_backprop = b->GetTupleElement(output, 2); } else { // Reduce over all dimensions except the feature dim. - std::vector reduction_dims(input_shape.dims() - 1); + std::vector reduction_dims(input_dims - 1); std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, 0); std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), @@ -165,9 +159,14 @@ class FusedBatchNormGradOp : public XlaOpKernel { // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + // epsilon)) // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) - offset_backprop = - b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), - *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(scale_dtype); + auto converted = + XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); // scratch1 = rsqrt(pop_var + epsilon) auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); @@ -175,17 +174,21 @@ class FusedBatchNormGradOp : public XlaOpKernel { b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); // scratch2 = sum(y_backprop * (x - mean)) - auto scratch2 = b->Reduce( - b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), - XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), - reduction_dims); + auto mul = + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})); + converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type); + reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduction_dims); + auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype); x_backprop = b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(0, + XlaHelpers::ConvertElementType(b, x_backprop, input_dtype)); ctx->SetOutput(1, scale_backprop); ctx->SetOutput(2, offset_backprop); ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 344a2ab2b6835c518c41de6f7a30fb2a34d130d2..569950c2dfaeb61028049a263a962dfa54a62e09 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,7 +159,9 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"), +REGISTER_XLA_OP(Name("BatchToSpaceND") + .CompileTimeConstInput("block_shape") + .CompileTimeConstInput("crops"), BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { @@ -182,9 +184,7 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace") - .CompileTimeConstInput("crops") - .CompileTimeConstInput("block_shape"), +REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"), BatchToSpaceOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index c667b4e3e326b776faba49387760abbd582fcc68..ed33b8ed2e823f313a9a7fe220390bc617288405 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -103,10 +103,15 @@ class BiasAddGradOp : public XlaOpKernel { std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), feature_dim + 1); - xla::ComputationDataHandle result = ctx->builder()->Reduce( - ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), reduce_dims); - ctx->SetOutput(0, result); + xla::ComputationBuilder* const b = ctx->builder(); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), reduce_dims); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0))); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 43a6a747c6bcc441f33f276fde4a66f367d99731..c52b2dcb7e9ef81fd52565dfbda05e33a52ed43a 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -62,5 +62,50 @@ class CastOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Cast"), CastOp); +class BitcastOp : public XlaOpKernel { + public: + explicit BitcastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &src_dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("type", &dst_dtype_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle output; + + if (src_dtype_ == dst_dtype_) { + output = input; + } else { + // The only complex type in XLA is C64, so error out if the bitcast has a + // complex source or destination type and the bitcast is not trivial. + OP_REQUIRES(ctx, + !xla::primitive_util::IsComplexType(src_type_) && + !xla::primitive_util::IsComplexType(dst_type_), + errors::Unimplemented("Complex types not supported.")); + // XLA bitcast requires that the bit-width of the source and destination + // matches, and currently only the simple lowering is performed. + OP_REQUIRES(ctx, + xla::primitive_util::BitWidth(src_type_) == + xla::primitive_util::BitWidth(dst_type_), + errors::Unimplemented( + "Only bitcasts between equally sized types supported.")); + output = builder->BitcastConvertType(input, dst_type_); + } + + ctx->SetOutput(0, output); + } + + protected: + DataType src_dtype_, dst_dtype_; + xla::PrimitiveType src_type_, dst_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(BitcastOp); +}; + +REGISTER_XLA_OP(Name("Bitcast"), BitcastOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6061e822d8d9c6c807a63aad4e9e9526a49e456c --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/check_numerics_op.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace { + +class CheckNumericsOp : public XlaOpKernel { + public: + explicit CheckNumericsOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + // TODO(b/32223192): add a real implementation of CheckNumerics + { + static mutex mu(tensorflow::LINKER_INITIALIZED); + static int log_counter = 0; + mutex_lock l(mu); + if (log_counter < 20) { + ++log_counter; + LOG(WARNING) << "Ignoring CheckNumerics operator " << name(); + } + } + ctx->SetOutput(0, ctx->Input(0)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CheckNumericsOp); +}; + +REGISTER_XLA_OP(Name("CheckNumerics"), CheckNumericsOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 81cea6d376d02c956a5257c5475fe5c10b83deb9..c0ee0c9c2ea849a692bee70bba36d32335eed9b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -58,7 +58,7 @@ xla::ComputationDataHandle CreateExpandedZero( // Create a mask for depthwise convolution that will make a normal convolution // produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// depthwise filter this returns a [2, 2, 3, 6] tensor // 1 1 0 0 0 0 1 1 0 0 0 0 // 0 0 1 1 0 0 0 0 1 1 0 0 // 0 0 0 0 1 1 0 0 0 0 1 1 @@ -166,6 +166,10 @@ xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); return builder->Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with + // ExpandedZero guarantees that only one element is non zero, so there + // cannot be accumulated precision error. builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), *ctx->GetOrCreateAdd(dtype), {expanded_filter_shape.dims() - 2}), diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..800ef5ab98d70ad822c6efffb33db28b46ae50fe --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class DynamicUpdateSliceOp : public XlaOpKernel { + public: + explicit DynamicUpdateSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + VLOG(3) << "DynamicUpdateSliceOp::Compile"; + + DataType index_type = input_type(2); + OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("index must be int32 or int64")); + + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape update_shape = ctx->InputShape(1); + const TensorShape index_shape = ctx->InputShape(2); + + OP_REQUIRES( + ctx, + TensorShapeUtils::IsVector(index_shape) && + index_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument("index must be a vector with length equal to " + "the number of input dimensions")); + OP_REQUIRES( + ctx, input_shape.dims() == update_shape.dims(), + errors::InvalidArgument("input and update must have the same rank," + " input shape is ", + input_shape.DebugString(), "; update shape is ", + update_shape.DebugString())); + + xla::ComputationDataHandle result = ctx->builder()->DynamicUpdateSlice( + ctx->Input(0), ctx->Input(1), ctx->Input(2)); + ctx->SetOutput(0, result); + } +}; + +REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 453a32c494b42e9922bc35fc526f3306530054fd..99470d70e709ddb5593c5eaae061bb897befc168 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -247,6 +247,8 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { const TensorShape gradient_shape = ctx->InputShape(0); xla::ComputationDataHandle input = ctx->Input(1); const DataType data_type = ctx->input_type(1); + const DataType accumulation_type = + XlaHelpers::SumAccumulationType(data_type); xla::ComputationDataHandle input_min = ctx->Input(2); xla::ComputationDataHandle input_max = ctx->Input(3); @@ -265,15 +267,23 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { ctx->SetOutput(0, output0); xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); + xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes); + xla::ComputationDataHandle reduce1 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select1, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); xla::ComputationDataHandle output1 = - b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); + xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes); + xla::ComputationDataHandle reduce2 = b->ReduceAll( + XlaHelpers::ConvertElementType(b, select2, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type)); xla::ComputationDataHandle output2 = - b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, - *ctx->GetOrCreateAdd(data_type)); + XlaHelpers::ConvertElementType(b, reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eefbe55c815d80a608bdf62d454a69d722adb158 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -0,0 +1,226 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/if_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &name_attr)); + then_branch_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &name_attr)); + else_branch_ = *name_attr; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); +} + +// TODO(b/35949885): There is duplication here with the handling of the +// while_op. Refactor the common code out/rework. +void XlaIfOp::Compile(XlaOpKernelContext* ctx) { + xla::ComputationBuilder* b = ctx->builder(); + + OP_REQUIRES(ctx, cond_type_ == DT_BOOL, + errors::InvalidArgument( + "Condition argument must be a boolean for XLA compilation")); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)), + errors::InvalidArgument( + "Condition argument must be a scalar for XLA compilation")); + + VLOG(1) << "Building If: " << input_types_.size() << " inputs"; + + std::vector inputs(input_types_.size()); + std::vector arguments(input_types_.size()); + for (int i = 0; i < input_types_.size(); ++i) { + XlaCompiler::Argument& arg = arguments[i]; + DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); + + arg.initialized = resource->initialized(); + arg.kind = XlaCompiler::Argument::kResource; + arg.resource_kind = resource->kind(); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + + arg.type = resource->type(); + arg.shape = resource->shape(); + OP_REQUIRES(ctx, arg.initialized, + errors::Unimplemented("Uninitialized arguments: ", arg.name)); + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + arg.name = resource->name(); + VLOG(2) << "Resource " << resource->name() + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = input_types_[i]; + arg.shape = ctx->InputShape(i + 1); + inputs[i] = ctx->Input(i + 1); + VLOG(2) << "Arg type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString(); + } + } + + // Compile both branches of the conditional. + XlaCompiler::CompileOptions options; + options.use_tuple_arg = true; + options.resolve_compile_time_constants = false; + options.return_updated_values_for_all_resources = true; + options.is_entry_computation = false; + XlaCompiler* compiler = ctx->compiler(); + + XlaCompiler::CompilationResult then_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + XlaCompiler::CompilationResult else_result; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + } + } + + // Check that both branches have identical input shapes. + OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape then_input_shape = then_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(then_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape else_input_shape = else_result.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(else_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), + errors::InvalidArgument( + "Input shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_input_shape), " vs. ", + xla::ShapeUtil::HumanString(else_input_shape))); + + // Check that both branches have identical output shapes. + OP_REQUIRES( + ctx, + xla::ShapeUtil::Compatible(then_result.xla_output_shape, + else_result.xla_output_shape), + errors::InvalidArgument( + "Output shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", + xla::ShapeUtil::HumanString(else_result.xla_output_shape))); + + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString(then_result.xla_output_shape); + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. + OP_REQUIRES(ctx, + then_result.resource_updates.size() == + else_result.resource_updates.size(), + errors::FailedPrecondition( + "Different number of resources in then and else branch")); + for (int i = 0; i < then_result.resource_updates.size(); ++i) { + const auto& lhs = then_result.resource_updates[i]; + const auto& rhs = else_result.resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + OP_REQUIRES( + ctx, equal, + errors::FailedPrecondition( + "Mismatch in resource of then and else branch for resource ", i)); + } + + xla::ComputationDataHandle outputs = + b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, + b->Tuple(inputs), *else_result.computation); + // Sets non-variable outputs. + for (int i = 0; i < output_types_.size(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i); + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Setting output " << i; + auto shape_or = b->GetShape(output_handle); + if (shape_or.ok()) { + LOG(INFO) << "Shape for output " << i << ": " + << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie()); + } else { + LOG(INFO) << "Shape unknown for output " << i; + } + } + ctx->SetOutput(i, output_handle); + } + } + + // Updates the values of any resource variables modified by the conditional + // bodies. + for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { + for (int i = 0; i < result->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = result->resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index + 1, &resource)); + if (update.modified) { + int pos = result->outputs.size() + i; + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + b->GetTupleElement(outputs, pos), b)); + } + VLOG(2) << "If variable: pos: " << update.input_index + << " name: " << resource->name() + << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + } + } + VLOG(1) << "Done building If"; +} + +REGISTER_XLA_OP(Name("XlaIf").AllowResourceTypes(), XlaIfOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f9bc98a198a72dcc0594e61971713bf890ce30b6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional conditional primitive. +// +// The outputs of the then/else branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in then/else bodies may read from and write to resource +// variables. +// Resource variables may be passed as arguments to the then/else function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the then/else bodies output. This ensures the then/else bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaIfOp : public XlaOpKernel { + public: + explicit XlaIfOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaIfOp); + + NameAttrList then_branch_; + NameAttrList else_branch_; + DataType cond_type_; + DataTypeVector input_types_; + DataTypeVector output_types_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index f22f384256a8ddd8c05de4a1322aba741dc4d7fd..5eeda79a935e8194a596d322b52add27846d378c 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -180,9 +180,13 @@ class AdjustContrastOpV2 : public XlaOpKernel { DataType type = context->input_type(0); - auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), - /*computation=*/*context->GetOrCreateAdd(type), + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, input, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *context->GetOrCreateAdd(accumulation_type), {height_dim, width_dim}); + auto output = XlaHelpers::ConvertElementType(b, reduce, type); output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); std::vector broadcast_dims(input_shape.dims() - 2); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index d096415087e47a73503a06526ab133ac34803c5d..c177f08d9c4687bb13b98a4328bb3960519799c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -29,21 +29,22 @@ class L2LossOp : public XlaOpKernel { explicit L2LossOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); + std::vector dims(ctx->InputShape(0).dims()); + std::iota(dims.begin(), dims.end(), 0); DataType dtype = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - - auto zero = XlaHelpers::Zero(b, dtype); - auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); - const xla::Computation& add = *ctx->GetOrCreateAdd(dtype); - - std::vector dims(input_shape.dims()); - std::iota(dims.begin(), dims.end(), 0); + xla::ComputationBuilder* const b = ctx->builder(); // output = sum(t ** 2) / 2 - auto x = ctx->Input(0); - ctx->SetOutput(0, b->Div(b->Reduce(b->Mul(x, x), zero, add, dims), two)); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto t = + XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type); + auto square = b->Mul(t, t); + auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), dims); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype); + auto two = XlaHelpers::IntegerLiteral(b, dtype, 2); + ctx->SetOutput(0, b->Div(deconverted, two)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 759d1a1a2d996d4f5deb1774be7014bb6de30f40..1cfee3070f384af0a7441a9c860c530dd1b42187 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -47,12 +47,17 @@ class LRNOp : public XlaOpKernel { // We use a window of depth_radius_ * 2 + 1, to account for the current // element and a depth_radius_ on either side. - auto squared = builder->Mul(input, input); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, input, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto scale = builder->Pow( builder->Add(builder->ConstantR0(bias_), @@ -130,12 +135,17 @@ class LRNGradOp : public XlaOpKernel { // dyi *= out_grads[j] // grads[k] += dyi - auto squared = builder->Mul(in_image, in_image); - auto sqr_sum = builder->ReduceWindow( - squared, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); + auto converted = + XlaHelpers::ConvertElementType(builder, in_image, accumulation_type); + auto squared = builder->Mul(converted, converted); + auto reduce = builder->ReduceWindow( + squared, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto sqr_sum = + XlaHelpers::ConvertElementType(builder, reduce, input_type(0)); auto norm = builder->Add(builder->ConstantR0(bias_), @@ -146,11 +156,15 @@ class LRNGradOp : public XlaOpKernel { builder->Div(out_image, norm)), in_grads); - auto dy_reduced = builder->ReduceWindow( - dy, XlaHelpers::Zero(builder, input_type(0)), - *ctx->GetOrCreateAdd(input_type(0)), + auto converted_dy = + XlaHelpers::ConvertElementType(builder, dy, accumulation_type); + auto dy_reduce = builder->ReduceWindow( + converted_dy, XlaHelpers::Zero(builder, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); + auto dy_reduced = + XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); xla::ComputationDataHandle gradients = builder->Add( builder->Mul(in_image, dy_reduced), diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4fb5dd4e06c7c70591262c0d63a91c383a2a6e0..5f635dd1bc6122cfcac8163baafd95b13f157715 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -35,8 +35,11 @@ namespace { // Superclass of pooling ops. class PoolingOp : public XlaOpKernel { public: - PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) - : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims, + const DataType reduction_type) + : XlaOpKernel(ctx), + num_spatial_dims_(num_spatial_dims), + reduction_type_(reduction_type) { if (ctx->num_inputs() == 1) { std::vector ksize_int; std::vector stride_int; @@ -63,12 +66,10 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } // Method that builds an initial value to use in reductions. - virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) = 0; + virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0; // The reduction operation to apply to each window. - virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) = 0; + virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0; // A post-processing operation to apply on the outputs of the ReduceWindow. virtual xla::ComputationDataHandle PostProcessOutput( @@ -76,9 +77,6 @@ class PoolingOp : public XlaOpKernel { DataType dtype, const TensorShape& input_shape) = 0; void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - std::vector ksize = ksize_; std::vector stride = stride_; if (ctx->num_inputs() != 1) { @@ -106,16 +104,20 @@ class PoolingOp : public XlaOpKernel { stride.clear(); OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); } + const TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), " dimensions")); - const DataType type = input_type(0); - xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( - input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, - stride, padding_); - ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); + xla::ComputationBuilder* const b = ctx->builder(); + auto input = + XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); + auto reduce = ctx->builder()->ReduceWindow( + input, InitValue(b), *Reduction(ctx), ksize, stride, padding_); + auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + ctx->SetOutput(0, + PostProcessOutput(ctx, pooled, input_type(0), input_shape)); } protected: @@ -124,21 +126,21 @@ class PoolingOp : public XlaOpKernel { std::vector stride_; xla::Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; + DataType reduction_type_; }; class MaxPoolOp : public PoolingOp { public: MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ctx->input_type(0)) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::MinValue(b, data_type); + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + return XlaHelpers::MinValue(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateMax(dtype); + const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateMax(reduction_type_); } xla::ComputationDataHandle PostProcessOutput( @@ -209,15 +211,17 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( } // Build a matrix of all 1s, with the same width/height as the input. + const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); auto ones = ctx->builder()->Broadcast( - XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); + XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); // Perform a ReduceWindow with the same window size, strides, and padding // to count the number of contributions to each result element. - auto counts = ctx->builder()->ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), dtype), - *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, + auto reduce = ctx->builder()->ReduceWindow( + ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, xla::Padding::kSame); + auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); return ctx->builder()->Div(output, counts, window_dims); } @@ -226,16 +230,16 @@ static xla::ComputationDataHandle AvgPoolDivideByCount( class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) - : PoolingOp(ctx, num_spatial_dims) {} + : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, + /*reduction_type=*/ + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, - DataType data_type) override { - return XlaHelpers::Zero(b, data_type); + xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + return XlaHelpers::Zero(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx, - DataType dtype) override { - return ctx->GetOrCreateAdd(dtype); + const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + return ctx->GetOrCreateAdd(reduction_type_); } xla::ComputationDataHandle PostProcessOutput( @@ -455,14 +459,12 @@ class AvgPoolGradOp : public XlaOpKernel { gradients_shape, filter_shape, out_backprop_shape, stride_, padding_, data_format_, &dims)); + // The input gradients are computed by a convolution of the output gradients + // and the filter, with some appropriate padding. See the comment at the top + // of conv_grad_ops.h for details. + xla::ComputationBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - - // The input gradients are computed by a convolution of the output - // gradients - // and the filter, with some appropriate padding. See the comment at - // the top of conv_grad_ops.h for details. - DataType dtype = input_type(1); - + auto dtype = input_type(1); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; @@ -483,17 +485,18 @@ class AvgPoolGradOp : public XlaOpKernel { padding->set_interior_padding(dims.spatial_dims[i].stride - 1); } - auto zero = XlaHelpers::Zero(ctx->builder(), dtype); - auto padded_gradients = - ctx->builder()->Pad(out_backprop_div, zero, padding_config); + auto zero = XlaHelpers::Zero(b, dtype); + auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config); // in_backprop = padded_gradients ones std::vector ones(num_dims(), 1LL); - xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( - padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, + auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); + auto in_backprop = b->ReduceWindow( + XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), ksize_, /* window_strides=*/ones, xla::Padding::kValid); - - ctx->SetOutput(0, in_backprop); + ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); } protected: @@ -525,5 +528,172 @@ class AvgPool3DGradOp : public AvgPoolGradOp { REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), AvgPool3DGradOp); +class MaxPoolGradGradOp : public XlaOpKernel { + public: + MaxPoolGradGradOp(OpKernelConstruction* ctx, int num_spatial_dims) + : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + + OP_REQUIRES(ctx, ksize_.size() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES(ctx, stride_.size() == num_dims(), + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims(), " dimensions")); + + const TensorShape tensor_in_shape = ctx->InputShape(0); + const TensorShape tensor_out_shape = ctx->InputShape(1); + const TensorShape out_backprop_shape = ctx->InputShape(2); + + // For maxpooling, tensor_in should have num_dims() dimensions. + OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_in must be ", num_dims(), + "-dimensional")); + OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), + errors::InvalidArgument("tensor_out must be ", num_dims(), + "-dimensional")); + // For maxpooling, out_backprop should have num_dims() dimensions. + OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), + errors::InvalidArgument("out_backprop must be ", num_dims(), + "-dimensional")); + + // What we want to compute: + // Given y = MaxPool(x), and xs_grad = MaxPoolGrad(x, y, ys_grad) + // MaxPoolGradGrad computes {ys_grad}_grad given x, y, and {xs_grad}_grad. + // + // In the regular TF op, this amounts to selecting for each window the + // incoming backprop value from xs_grad_grad that corresponds to the maximal + // value in the corresponding window of x. + // + // TODO(b/73062247): What we really want is a ReduceWindow with different + // arrays for index selection vs return value selection--a select-to-gather. + // + // Here, we implement a bitwise hack: we use the hi 16 bits of input for + // separate max pooling alongside each of the hi and lo 16 bits of + // out_backprop packed into 16 lo bits, which we then glue back together at + // the end to get a full 32 bits of gradient. + // + // This could select the wrong backprop value for two x values that are + // equally maximal up to the first 16 bits, in which case we are taking the + // latter. + // + // Note that in principle we could use 32 separate maxpools to recover each + // of 32 bits of the gradient while preserving 31 bits of input for the max + // pooling criteria; here, we just truncate to the first 16 bits of input. + + auto input = ctx->Input(0); + auto out_backprop = ctx->Input(2); + + auto b = ctx->builder(); + + auto sixteen = b->ConstantR0(16); + // in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32 + auto in_hi = b->BitcastConvertType( + b->ConvertElementType(b->ConvertElementType(input, xla::BF16), + xla::F32), + xla::U32); + auto bp_int = b->BitcastConvertType(out_backprop, xla::U32); + auto bp_hi = b->ShiftRightLogical(bp_int, sixteen); + auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen); + auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add. + auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add. + + auto init_value = XlaHelpers::MinValue(b, DT_FLOAT); + // We will reduce by taking the maximal value up to 16 bits (ignoring the lo + // 16 bits of packed-in hi/lo backprop value). + auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits"); + { + // F32 parameters to satisfy lowering type restriction for reduce opcode. + const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {}); + auto lhs = rb->Parameter(0, scalar, "lhs"); + auto rhs = rb->Parameter(1, scalar, "rhs"); + auto sixteen = rb->ConstantR0(16); + auto lhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen), + sixteen); + auto rhs_criteria = rb->ShiftLeft( + rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen), + sixteen); + // Must use a F32 comparison, because S32 would not work for negatives. + rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32), + rb->BitcastConvertType(rhs_criteria, xla::F32)), + lhs, rhs); + } + auto reduce = rb->BuildAndNoteError(); + xla::Padding xla_padding = + (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; + auto pooled_hi = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto pooled_lo = + b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32), + init_value, reduce, ksize_, stride_, xla_padding); + auto grads_hi = + b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen); + auto grads_lo = b->ShiftRightLogical( + b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen), + sixteen); + auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add. + + xla::PrimitiveType element_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); + ctx->SetOutput(0, b->BitcastConvertType(grads, element_type)); + } + + protected: + const int num_spatial_dims_; + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_ = FORMAT_NHWC; +}; + +class MaxPool2DGradGradOp : public MaxPoolGradGradOp { + public: + explicit MaxPool2DGradGradOp(OpKernelConstruction* ctx) + : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/2) { + string data_format; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); + OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + } +}; +REGISTER_XLA_OP(Name("MaxPoolGradGrad").TypeConstraint("T", DT_FLOAT), + MaxPool2DGradGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradGradV2") + .TypeConstraint("T", DT_FLOAT) + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradGradOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb144bea9e429b7c8bcc3d07f688ed6a254c3be0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class ReduceWindowOp : public XlaOpKernel { + public: + explicit ReduceWindowOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); + OP_REQUIRES_OK(context, + context->GetAttr("window_dimensions", &window_dimensions_)); + OP_REQUIRES_OK(context, + context->GetAttr("window_strides", &window_strides_)); + OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); + OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const DataType dtype = context->input_type(0); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == window_dimensions_.size(), + errors::InvalidArgument( + "The size of window_dimensions must be equal to the input " + "rank (", + window_dimensions_.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides_.size(), + errors::InvalidArgument( + "The size of window_strides must be equal to the input " + "rank (", + window_strides_.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_low_.size(), + errors::InvalidArgument( + "The size of padding_low must be equal to the input " + "rank (", + padding_low_.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_high_.size(), + errors::InvalidArgument( + "The size of padding_high must be equal to the input " + "rank (", + padding_high_.size(), " vs. ", rank, ")")); + + xla::ComputationBuilder* builder = context->builder(); + + // Build the reducer function. + XlaCompiler::Argument reducer_arg; + reducer_arg.kind = XlaCompiler::Argument::kParameter; + reducer_arg.type = dtype; + reducer_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult reducer; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *computation_, + {reducer_arg, reducer_arg}, &reducer)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES(context, + xla::ShapeUtil::Compatible( + reducer.xla_output_shape, + xla::ShapeUtil::MakeTupleShape({scalar_shape})), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + // Wraps the reducer in a computation that unpacks the output tuple. + xla::Computation wrapper; + { + std::unique_ptr cb = + builder->CreateSubBuilder("wrapper"); + auto x = cb->Parameter(0, scalar_shape, "x"); + auto y = cb->Parameter(1, scalar_shape, "y"); + auto outputs = cb->Call(*reducer.computation, {x, y}); + cb->GetTupleElement(outputs, 0); + xla::StatusOr result = cb->Build(); + OP_REQUIRES_OK(context, result.status()); + wrapper = std::move(result.ValueOrDie()); + } + + std::vector> padding(rank); + for (int i = 0; i < rank; ++i) { + padding[i] = {padding_low_[i], padding_high_[i]}; + } + + xla::ComputationDataHandle output = builder->ReduceWindowWithGeneralPadding( + context->Input(0), context->Input(1), wrapper, window_dimensions_, + window_strides_, padding); + context->SetOutput(0, output); + } + + private: + const NameAttrList* computation_; + std::vector window_dimensions_; + std::vector window_strides_; + std::vector padding_low_; + std::vector padding_high_; + + TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); +}; + +REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 03b13b2924f4b81c1017804c91d5ffb81c44ea0b..812d258cd1677e18ef49952044126c76a2f55b19 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -27,7 +27,13 @@ namespace { class SumOp : public XlaReductionOp { public: - explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit SumOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { @@ -39,11 +45,13 @@ REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: - explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit ProdOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - return XlaHelpers::One(builder, input_type(0)); + return XlaHelpers::One(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -58,13 +66,12 @@ REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), class MinOp : public XlaReductionOp { public: - explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MinOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MaxValue(type)); + return XlaHelpers::MaxValue(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -78,13 +85,12 @@ REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: - explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MaxOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { - xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::Literal::MinValue(type)); + return XlaHelpers::MinValue(builder, reduction_type_); } void BuildReducer(xla::ComputationBuilder* builder, @@ -98,8 +104,14 @@ REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: - explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit MeanOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, + XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} + xla::ComputationDataHandle InitialValue( + xla::ComputationBuilder* builder) override { + return XlaHelpers::Zero(builder, reduction_type_); + } void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { @@ -121,7 +133,8 @@ REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), class AllOp : public XlaReductionOp { public: - explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AllOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { @@ -139,7 +152,8 @@ REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: - explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {} + explicit AnyOp(OpKernelConstruction* ctx) + : XlaReductionOp(ctx, ctx->input_type(0)) {} xla::ComputationDataHandle InitialValue( xla::ComputationBuilder* builder) override { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 9aca6d8fedf92f176b3b7b40c5961d4a2e557a8a..f3181f0dadc2d3f45abb145e009e2663c10490f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -33,12 +33,12 @@ namespace tensorflow { // xla::ComputationBuilder. class XlaReductionOp : public XlaOpKernel { public: - explicit XlaReductionOp(OpKernelConstruction* ctx); + XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); ~XlaReductionOp() override {} - // Return the base case for the reduction. Defaults to zero. + // Return the base case for the reduction. virtual xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder); + xla::ComputationBuilder* builder) = 0; // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired @@ -63,6 +63,9 @@ class XlaReductionOp : public XlaOpKernel { private: // True if the number of dimensions should be maintained. bool keep_dims_; + + protected: + DataType reduction_type_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4b5d09eb9fd4110cdc4221099ff55767e9132540..64fe765ae9a945c58ea60bc157b1520c83b0d8e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -24,19 +24,15 @@ limitations under the License. namespace tensorflow { -XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { +XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, + DataType reduction_type) + : XlaOpKernel(ctx), reduction_type_(reduction_type) { const DataType dt = BaseType(input_type(0)); OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); } -// Return the base case for the reduction. Defaults to zero. -xla::ComputationDataHandle XlaReductionOp::InitialValue( - xla::ComputationBuilder* builder) { - return XlaHelpers::Zero(builder, input_type(0)); -} - // Unless BuildFinalizer is overridden the reduction has no // finalizer. xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( @@ -100,36 +96,26 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { string desc = ctx->op_kernel().name(); - // Call virtual method to get the initial value. - const xla::ComputationDataHandle initial = InitialValue(ctx->builder()); + xla::ComputationBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::ComputationBuilder r(ctx->builder()->client(), - strings::StrCat(desc, "-reduction")); + xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction")); xla::PrimitiveType type; - TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - // Make two scalar parameters of the desired type for the lambda. - xla::ComputationDataHandle rx = - r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); - xla::ComputationDataHandle ry = - r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); - - auto data = ctx->Input(0); + TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); + auto data = b->ConvertElementType(ctx->Input(0), type); + // Call virtual method to get the initial value. + auto initial = b->ConvertElementType(InitialValue(b), type); + // Make two scalar parameters of the desired type for the lambda. + auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x"); + auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); - xla::ComputationDataHandle reduce = - ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes); - xla::ComputationDataHandle finalized = - BuildFinalizer(ctx->builder(), reduce, num_elements_reduced); - - xla::ComputationDataHandle result; - if (keep_dims_) { - result = ctx->builder()->Reshape(finalized, final_shape); - } else { - result = finalized; - } + auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); + auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); + auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced); + auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized; ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ee4a94164c4a43828eb4feedbfa9d1a9e231ef8f..4cfa28a0ce3d7d1f24196ef6ef2775f840b2bcf1 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -66,7 +66,7 @@ class ScanOp : public XlaOpKernel { -input_shape.dims(), ", ", input_shape.dims(), "), but got ", axis)); - DataType dtype = ctx->input_type(0); + DataType dtype = XlaHelpers::SumAccumulationType(ctx->input_type(0)); if (input_shape.num_elements() == 0) { // Exit early if there is nothing to compute. @@ -91,7 +91,6 @@ class ScanOp : public XlaOpKernel { std::swap(padding[axis].first, padding[axis].second); } - xla::ComputationDataHandle input = ctx->Input(0); xla::ComputationDataHandle init; const xla::Computation* reducer; if (sum_) { @@ -102,7 +101,10 @@ class ScanOp : public XlaOpKernel { reducer = ctx->GetOrCreateMul(dtype); } auto output = builder->ReduceWindowWithGeneralPadding( - ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, + *reducer, window_dims, window_strides, padding); + output = + XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); // In exclusive mode, we have computed an extra element containing the sum // of all the input elements. Slice off this extra "last" element. diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 80d6df6c48b0141734dcee1c2a3c413926931feb..498342a98881df0c6ff50007eacc1d5ef6196b57 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -83,7 +83,9 @@ class UnsortedSegmentSum : public XlaOpKernel { DataType dtype_; }; -REGISTER_XLA_OP(Name("UnsortedSegmentSum"), UnsortedSegmentSum); +REGISTER_XLA_OP( + Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"), + UnsortedSegmentSum); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 5172781c0d05b6682fe92086654e3b86961949ee..d079b89861817a5639ac72b5ee49d76cb4506ae8 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -48,7 +48,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) { ctx->builder()->Send(ctx->Input(0), channel); } -REGISTER_XLA_OP(Name("_XLASend"), SendOp); +REGISTER_XLA_OP(Name("XlaSend"), SendOp); class RecvOp : public XlaOpKernel { public: @@ -68,7 +68,7 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { TensorShape tensor_shape; DataType dtype; OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &tensor_shape)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype)); OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tensor_shape, &shape_)); } @@ -79,7 +79,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel)); } -REGISTER_XLA_OP(Name("_XLARecv"), RecvOp); +REGISTER_XLA_OP(Name("XlaRecv"), RecvOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 750a4c2dec8154f97f307978b3d8884271292279..463788b8b461c370a8e7ab4d79a94fc0143b8b45 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -28,7 +29,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = StringPiece(type_string()).starts_with("Log"); + log_ = str_util::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { @@ -42,9 +43,8 @@ class SoftmaxOp : public XlaOpKernel { const DataType type = input_type(0); auto logits = ctx->Input(0); - xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationBuilder* const b = ctx->builder(); const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = @@ -52,21 +52,20 @@ class SoftmaxOp : public XlaOpKernel { // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim}); - xla::ComputationDataHandle softmax; - if (log_) { - // softmax = shifted_logits - log(sum(exp(shifted_logits))) - auto log_sum_exp = - b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type), - add_func, {kClassDim})); - softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); - } else { - // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) - auto exp_shifted = b->Exp(shifted_logits); - auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func, - {kClassDim}); - softmax = b->Div(exp_shifted, sum_exp, {kBatchDim}); - } - + auto exp_shifted = b->Exp(shifted_logits); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + auto reduce = + b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum = XlaHelpers::ConvertElementType(b, reduce, type); + auto softmax = + log_ + // softmax = shifted_logits - log(sum(exp(shifted_logits))) + ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim}) + // softmax = exp(shifted_logits) / sum(exp(shifted_logits)) + : b->Div(exp_shifted, sum, {kBatchDim}); ctx->SetOutput(0, softmax); } @@ -82,7 +81,6 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, const xla::ComputationDataHandle& logits, const xla::ComputationDataHandle& labels) { const xla::Computation& max_func = *ctx->GetOrCreateMax(type); - const xla::Computation& add_func = *ctx->GetOrCreateAdd(type); const int kBatchDim = 0; const int kClassDim = 1; @@ -100,8 +98,12 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, auto exp_shifted_logits = b->Exp(shifted_logits); // sum_{class} (exp(logits - max_logits)) - auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type), - add_func, {kClassDim}); + const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + auto converted = + XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type); + auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type); // log(sum(exp(logits - max_logits))) auto log_sum_exp = b->Log(sum_exp); @@ -110,9 +112,13 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes // (The subtraction broadcasts along the batch dimension.) - xla::ComputationDataHandle loss = b->Reduce( - b->Mul(b->Neg(labels), b->Sub(shifted_logits, log_sum_exp, {kBatchDim})), - XlaHelpers::Zero(b, type), add_func, {kClassDim}); + auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim}); + auto mul = b->Mul(b->Neg(labels), sub); + auto sum = + b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type), + XlaHelpers::Zero(b, accumulation_type), + *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); + auto loss = XlaHelpers::ConvertElementType(b, sum, type); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 79c435c90a1f57250be90c2c2523bf3d7d231461..43c15e753805352875034dfd2c70a2a1ed9a4114 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -111,27 +111,24 @@ class SplitVOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const int32 num_split = num_outputs(); + const TensorShape input_shape = ctx->InputShape(0); const TensorShape index_shape = ctx->InputShape(2); - xla::Literal literal_index; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index)); - int32 split_dim; - OP_REQUIRES(ctx, index_shape.dims() == 0, - errors::InvalidArgument("split_dim input to Split Op must be a " - "scalar")); - split_dim = literal_index.Get({}); + int64 split_dim_orig; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig)); + int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() + : split_dim_orig; + OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("-input rank(-", input_shape.dims(), + ") <= split_dim < input rank (", + input_shape.dims(), "), but got ", + split_dim_orig)); xla::ComputationDataHandle input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() > 0, errors::InvalidArgument("Can't split a 0 dimensional input")); - OP_REQUIRES( - ctx, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); - OP_REQUIRES( ctx, num_split > 0, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index b10880de77e6b9811008076cd4a959c284e558d1..5bb773d97fc5ce90dabceeefd5c29d916597f5ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -239,6 +239,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomUniform") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomUniformOp); @@ -272,6 +273,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { // TODO(phawkins): generalize to non-float, non-int32 seed types. REGISTER_XLA_OP(Name("StatelessRandomNormal") + .CompileTimeConstInput("shape") .TypeConstraint("dtype", DT_FLOAT) .TypeConstraint("Tseed", DT_INT32), StatelessRandomNormalOp); diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 91c169428c7a88a8d107a97445aeea999946e3e9..6204aa4e27000fddec7f5b82b2198d37956f6aba 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -77,13 +77,14 @@ class StridedSliceOp : public XlaOpKernel { for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { slice_begin.push_back(begin[i]); - slice_end.push_back(end[i]); + slice_end.push_back(std::max(end[i], begin[i])); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); - slice_end.push_back(input_shape.dim_size(i) - end[i] - 1); + slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, + input_shape.dim_size(i) - begin[i] - 1)); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0c5ad9e5255ffc3dfcfb83335060ae833937b3ce..7cb47f908d4ff43f455f1e77c53cd3cc956579ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -60,11 +60,13 @@ XLAJIT_MAKE_UNARY( b->Add(XlaHelpers::One(b, input_type(0)), x)))); // acosh(x) = log(x + sqrt(x^2 - 1)) +// = log(x + sqrt((x+1)*(x-1))) XLAJIT_MAKE_UNARY( Acosh, - b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + b->Log(b->Add(x, + b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), + b->Sub(x, XlaHelpers::One(b, input_type(0)))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 488fda74bf7b5c1d66f8d706a1be3cc1fc29a492..fde1977c1b1834156b87b4fb3516f7bf8df435d7 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -39,6 +39,7 @@ cc_library( ":batch_dot", ":triangular_solve", ":util", + ":while_loop", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -126,6 +127,30 @@ cc_library( ], ) +xla_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop", srcs = ["while_loop.cc"], @@ -140,17 +165,3 @@ cc_library( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index e795701181dd80a2ff544743d513bffd52fd2399..203365e2ab07e0da1abfac5452a8ec41a4ddf406 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,68 +32,122 @@ namespace tensorflow { namespace { +// The Cholesky–Banachiewicz algorithm. See +// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms +// for a description. +// // def cholesky_unblocked(a): // assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1] // n = a.shape[-2] // l = np.zeros_like(a) // for j in xrange(n): -// r = l[..., j, :j] -// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r)) -// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], -// np.transpose(r))) / l[..., j, j] +// row = l[..., j, :j] +// row_t = np.swapaxes(row, -1, -2) +// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t)) +// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / +// l[..., j, j] // return l xla::StatusOr CholeskyUnblocked( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(a)); - xla::ComputationDataHandle l = Zeros(builder, *shape); - const int64 n = xla::ShapeUtil::GetDimension(*shape, -2); - for (int j = 0; j < n; ++j) { - // Picture of block structure: - // ... \ - // \ - // -- r -- d - // |\ - // B c \ - // | \ - // | ... - // - // ^ - // column j - TF_ASSIGN_OR_RETURN(auto d, - SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1})); - TF_ASSIGN_OR_RETURN(auto c, - SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1})); - xla::ComputationDataHandle new_d_squared = d; - xla::ComputationDataHandle br; - if (j > 0) { - TF_ASSIGN_OR_RETURN(auto r, - SliceInMinorDims(builder, l, {j, 0}, {j + 1, j})); - TF_ASSIGN_OR_RETURN(auto b, - SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); - TF_ASSIGN_OR_RETURN(auto r_squared, - BatchDot(builder, r, r, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - new_d_squared = builder->Sub(new_d_squared, r_squared); + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(*a_shape); + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape->dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); - TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, - /*transpose_y=*/true, - /*conjugate_x=*/false, - /*conjugate_y=*/false)); - } - auto new_d_inv = builder->Pow( - new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); - auto new_d = builder->Mul(new_d_inv, new_d_squared); - TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j})); + xla::ComputationDataHandle l = Zeros(builder, *a_shape); - if (j > 0) { - c = builder->Sub(c, br); + // Construct the for loop body to iterate over rows. + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); } - auto new_c = builder->Mul(c, new_d_inv); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j})); - } - return l; + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape->element_type()); + auto mask_zeros_row = Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape->element_type()); + auto mask_zeros_col = Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = body_builder->ConstantR1(mask_vector); + auto mask_range_row = body_builder->Broadcast( + body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = body_builder->Broadcast( + body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = body_builder->ConstantR0(0); + TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, + {i, zero}, {1, n})); + auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), + mask_zeros_row, l_i); + // a[..., i, i] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, + {i, i}, {1, 1})); + // np.dot(row, np.swapaxes(row, -1, -2)) + xla::ComputationDataHandle diag_dot; + TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, + /*transpose_x=*/false, + /*transpose_y=*/true)); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = body_builder->Pow( + body_builder->Sub(a_ii, diag_dot), + FloatLiteral(body_builder, a_shape->element_type(), 0.5)); + + // a[..., i+1:, i] + auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); + // select the whole i-th column, then mask out all rows above i+1 + TF_ASSIGN_OR_RETURN( + auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); + auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), + mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true)); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), + mask_zeros_col, dot); + + auto col_update = + body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); + TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( + body_builder, body_l, col_update, {i})); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. + TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( + body_builder, body_l, l_ii, {i, i})); + + return std::vector{body_a, body_l}; + }; + + TF_ASSIGN_OR_RETURN( + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; } } // namespace diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index e083a383be4be0d1b556b63214fe5f70323b4149..17da8d8b22d107701ce768ac945c1404df6d47e8 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. -// TODO(mattjj): handle the complex Hermitian case +// TODO(znado): handle the complex Hermitian case xla::StatusOr Cholesky( xla::ComputationBuilder* builder, xla::ComputationDataHandle a, int64 block_size = 256); diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 6009243f9774eea24e8049e2bd50fe32f291132f..45699233ea8b2a75e3850098250307b95546cc28 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -141,6 +141,8 @@ xla::StatusOr XlaScatter( body_builder->ConstantR0(true), xla::CreateScalarAndComputation(body_builder)); + // Make the index in bounds to prevent implementation defined behavior. + index = body_builder->Max(index, zero_index); index = body_builder->Pad( index, zero_index, xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); @@ -157,8 +159,8 @@ xla::StatusOr XlaScatter( auto update = body_builder->DynamicSlice(updates, updates_offset, flat_updates_slice_shape); - // Unflatten the major (iteration) dimensions of the slice to their original - // shape. + // Unflatten the major (iteration) dimensions of the slice to their + // original shape. std::vector updates_slice_shape(num_index_dims, 1); updates_slice_shape.insert(updates_slice_shape.end(), buffer_shape_post_axes.begin(), @@ -167,15 +169,16 @@ xla::StatusOr XlaScatter( // Apply the update to the buffer. If there is a combiner, use it to merge // the current values with the update. + auto current_value = + body_builder->DynamicSlice(buffer, index, updates_slice_shape); if (combiner) { - auto current_value = - body_builder->DynamicSlice(buffer, index, updates_slice_shape); update = combiner(current_value, update, body_builder); } - // Apply the update if it is in range. - buffer = body_builder->Select( - index_in_range, body_builder->DynamicUpdateSlice(buffer, update, index), - buffer); + // Use the current value instead of the update if the index is out of + // bounds. + update = body_builder->Select(index_in_range, update, current_value); + // Apply the update. + buffer = body_builder->DynamicUpdateSlice(buffer, update, index); return std::vector{indices, updates, buffer}; }; diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index f579669bbd852b514e021ce71d635f8ce5e4fe4d..31d823ca336039f691f2c16e37028c0de98b1ee5 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -140,13 +140,47 @@ xla::StatusOr SliceInMinorDims( return builder->Slice(x, padded_start, padded_end, strides); } +std::vector PrependMajorDims(xla::ComputationBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices) { + std::vector output(indices.size() + major_dims.size()); + std::copy(major_dims.begin(), major_dims.end(), output.begin()); + std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); + return output; +} + +xla::StatusOr DynamicSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts, + const gtl::ArraySlice& sizes) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); + return builder->DynamicSlice(x, padded_starts, padded_sizes); +} + xla::StatusOr UpdateSlice( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start) { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); - return builder->DynamicUpdateSlice( - x, update, builder->ConstantR1(start_as_int32)); + auto start_constant = builder->ConstantR1(start_as_int32); + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_ASSIGN_OR_RETURN(std::unique_ptr start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(*start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return builder->DynamicUpdateSlice(x, update, start_constant); } xla::StatusOr UpdateSliceInMinorDims( @@ -162,6 +196,29 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + return builder->DynamicUpdateSlice(x, update, padded_starts); +} + +xla::StatusOr PrependZerosInMajorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + auto zero = builder->Reshape(builder->ConstantR0(0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = + builder->Reshape(starts[i], {1}); + } + return builder->ConcatInDim(padded_starts, 0); +} + xla::StatusOr TransposeInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 51f8baaf00bd8fd25baa1a87be8cb0089dfb22b5..b684123f1363cff9e6ac4314cc3a8ae7630cbdf3 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -32,16 +32,39 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, double value); +// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros +// prepended until the array is length n_dims. +xla::ComputationDataHandle PrependZerosInMajorDims( + xla::ComputationBuilder* builder, + gtl::ArraySlice starts); + // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, int64 value); +// Builds a vector of zeros of length rank(x) with the last two values being +// those in `starts`. +xla::StatusOr PrependZerosInMajorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts); + // Performs a slice in the minor dimensions of a Tensor. xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, gtl::ArraySlice start, gtl::ArraySlice end); +// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. +std::vector PrependMajorDims(xla::ComputationBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices); + +// Performs a dynamic slice in the minor dimensions of a Tensor. +xla::StatusOr DynamicSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const std::vector& starts, + const gtl::ArraySlice& sizes); + // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update xla::StatusOr UpdateSlice( @@ -54,6 +77,11 @@ xla::StatusOr UpdateSliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, + const std::vector& starts); + // Transposes a stack of matrices `x` by swapping the last two dimensions. xla::StatusOr TransposeInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b6bd33af2e42a4ab93a22528fd49ef53c46bb479 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using UtilTest = xla::ClientLibraryTestBase; +using UtilLeftLookingTest = xla::ClientLibraryTestBase; + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(UtilTest, Simple2dLookup) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, x, y; + auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); + auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); + auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); + TF_ASSERT_OK(result.status()); + + ComputeAndCompareR2(&builder, {{10}}, + {a_data.get(), x_data.get(), y_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(UtilTest, Simple3dLookup) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, 4})); + + ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, + {a_data.get(), index_data.get()}); +} + +XLA_TEST_F(UtilTest, SimpleSliceUpdate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b, x, y; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); + auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); + + auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected( + {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); + + ComputeAndCompareR2( + &builder, expected, + {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); +} + +XLA_TEST_F(UtilTest, RowBatchDot) { + xla::ComputationBuilder builder(client_, TestName()); + + int n = 4; + + xla::ComputationDataHandle a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). + auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, n})); + TF_ASSERT_OK_AND_ASSIGN( + auto dot, BatchDot(&builder, l_index, row, + /*transpose_x=*/false, /*transpose_y=*/true)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 86c02ac2e65c12d3527c4022df0cc603e522ef7a..495d9c60780b0a728e8dbfb4537d33d92b4bb5b7 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -54,7 +54,6 @@ xla::StatusOr> XlaWhileLoop( auto result, condition_function(unpack_tuple(parameter, arity, cond_builder.get()), cond_builder.get())); - TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result)); } TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index 98f72b3792eb147f5a1847c5e1ecef18bccbca5f..bb9168fa358154f3db9dab87bacc9bf28dd16406 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -7,17 +7,13 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( - name = "functional_ops", - srcs = ["functional_ops.cc"], - deps = [ - "//tensorflow/core:framework", + name = "xla_ops", + srcs = [ + "dynamic_slice_ops.cc", + "functional_ops.cc", + "reduce_window_op.cc", + "sendrecv_ops.cc", ], - alwayslink = 1, -) - -cc_library( - name = "sendrecv_ops", - srcs = ["sendrecv_ops.cc"], deps = [ "//tensorflow/core:framework", ], @@ -25,31 +21,9 @@ cc_library( ) tf_gen_op_wrapper_py( - name = "gen_functional_ops", - out = "gen_functional_ops.py", - deps = [ - ":functional_ops", - ], -) - -tf_gen_op_wrapper_py( - name = "gen_sendrecv_ops", - out = "gen_sendrecv_ops.py", + name = "gen_xla_ops", + out = "gen_xla_ops.py", deps = [ - ":sendrecv_ops", + ":xla_ops", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6c0edbb889b1751ac9d9d47d0c9534b543196ff --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9af982adc090ea78c711fd4656ba429c53b18c9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc index 4b41c16a8b3fdc0c3412c76d29d3ec2b7bdfd0aa..7ec7b50e905a6cbdecea4543dcb87322b5a7e844 100644 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc @@ -18,22 +18,24 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XLASend") +REGISTER_OP("XlaSend") .Input("tensor: T") .Attr("T: type") .Attr("tensor_name: string") .SetIsStateful() .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( -Sends the named tensor to another XLA computation. +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . tensor: The tensor to send. -tensor_name: The name of the tensor to send. +tensor_name: A string key that identifies the channel. )doc"); -REGISTER_OP("_XLARecv") - .Output("tensor: T") - .Attr("T: type") +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") .Attr("tensor_name: string") .Attr("shape: shape") .SetIsStateful() @@ -46,11 +48,14 @@ REGISTER_OP("_XLARecv") return Status::OK(); }) .Doc(R"doc( -Receives the named tensor from another XLA computation. +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . tensor: The tensor to receive. -tensor_name: The name of the tensor to receive. -shape: The shape of the input tensor. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. )doc"); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..42b6292f79ffddd155c05758a1420a2a583eb0c6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -0,0 +1,32 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//learning/tfx:__subpackages__", + "//tensorflow:internal", + ], +) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_py_clif_cc", +) + +tf_py_clif_cc( + name = "xla_op_registry", + srcs = ["xla_op_registry.clif"], + pyclif_deps = [ + "//tensorflow/core:framework/kernel_def_pyclif", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + ], +) + +py_library( + name = "xla", + srcs = ["xla.py"], + deps = [ + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + ], +) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ce65bec950fdfd38c3ca5bc62ac745ef8ca4a7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental library that exposes XLA operations directly in TensorFlow. + +It is sometimes useful to be able to build HLO programs directly from +TensorFlow. This file provides Tensorflow operators that map as closely as +possible to HLO operators. + +There is no promise of backward or forward compatibility for operators defined +in this module. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tf2xla.ops import gen_xla_ops + +# TODO(phawkins): provide wrappers for all XLA operators. + +dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice + + +def reduce_window(operand, + init, + reducer, + window_dimensions, + window_strides=None, + padding=None, + name=None): + """Wraps the XLA ReduceWindow operator. + + ReduceWindow is documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + + Args: + operand: the input tensor + init: a scalar tensor representing the initial value for the reduction + reducer: a reduction function that combines a pair of scalars. + window_dimensions: shape of the window, as a list of integers + window_strides: inter-window strides, as a list of integers. Optional; + if omitted, defaults to strides of 1. + padding: padding to apply to 'operand'. List of (low, high) pairs of + integers that specify the padding to apply before and after each + dimension. Optional; if omitted, defaults to no padding. + name: the operator name, or None. + Returns: + A tensor that represents the output of the reduce_window operator. + """ + window_strides = window_strides or [1] * len(window_dimensions) + padding = padding or [(0, 0)] * len(window_dimensions) + padding_low = [x for (x, _) in padding] + padding_high = [y for (_, y) in padding] + return gen_xla_ops.xla_reduce_window( + operand, + init, + reducer, + window_dimensions, + window_strides, + padding_low, + padding_high, + name=name) + + +recv = gen_xla_ops.xla_recv +send = gen_xla_ops.xla_send + +while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/python/xla_op_registry.clif b/tensorflow/compiler/tf2xla/python/xla_op_registry.clif new file mode 100644 index 0000000000000000000000000000000000000000..e1ee6cc656a314876fc1fabbebe1bee39a6b2831 --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/xla_op_registry.clif @@ -0,0 +1,7 @@ +from "third_party/tensorflow/core/framework/kernel_def_pyclif.h" import * # KernelDef + +from "third_party/tensorflow/compiler/tf2xla/xla_op_registry.h": + namespace `tensorflow`: + def `XlaOpRegistry::DeviceKernels` as + device_kernels(device: str, include_compilation_only_kernels: bool) -> + list diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 1a0e09758f7cc6714793300c6ece14093a8ad246..5759c72af301785f3ca1110b58eeb2fe7dead713 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -65,8 +66,8 @@ ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !StringPiece(parsed_device.type) - .contains(kDeviceSuffixReplicatedCore)) { + !str_util::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index a9978e697b091715ce120f0d18fdddd259e08b32..b813668a9edd3a704a9dca1eaa588c1eced6ac31 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -90,6 +90,11 @@ TEST(ConvertGraphDefToXla, Sum) { TF_EXPECT_OK(result_or.status()); std::unique_ptr result = std::move(result_or.ValueOrDie()); EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + + config.mutable_feed(0)->mutable_id()->set_output_index( + 123); /* invalid output_index */ + EXPECT_TRUE(errors::IsInvalidArgument( + ConvertGraphDefToXla(graph_def, config, client, &computation))); } } // namespace diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index f428a194328935fec1210ea96245344de859e611..7ec85aa3cdec622cae509f45c5ba7740222025f9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -151,8 +151,15 @@ Status AddPlaceholdersForFeeds( Status status; Node* feed_node = g.AddNode(gd.node(0), &status); TF_RETURN_IF_ERROR(status); - info.data_type = - BaseType(feed_node->output_type(info.feed->id().output_index())); + + if (info.feed->id().output_index() < feed_node->num_outputs()) { + info.data_type = + BaseType(feed_node->output_type(info.feed->id().output_index())); + } else { + return errors::InvalidArgument( + "Invalid output_index ", info.feed->id().output_index(), + " for feed node ", info.feed->id().node_name()); + } } } @@ -281,4 +288,13 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } +void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, + KernelDef* kdef) { + for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { + if (constraint.name() == name) { + constraint.mutable_allowed_values()->mutable_list()->add_type(dtype); + } + } +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index e5fba8ede7745febbb42c572a7b52247213afc95..745beb39c1d917cd0d1cd219536ee26a96253ec9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -51,6 +52,10 @@ string TensorIdToString(const tf2xla::TensorId& id); // edges are considered. Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); +// Add an allowed data type to the AttrConstraint with the given name. +void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, + KernelDef* kdef); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ed10d80609641b090cf78bf2e17364fe2fa89c31..ae51446204baf14dc03fc6305641048dbf3872b0 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -33,7 +34,7 @@ namespace { void ExpectErrorContains(const Status& status, StringPiece str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 15bba46ac62a97592656942afc767a303c9b97f3..86263d847ae02d50e70dafb0129b2664c522f2a3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -365,6 +365,13 @@ Status BuildComputation( return a->arg_num() < b->arg_num(); }); + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + for (const XlaResource* resource : arg_resources) { const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; @@ -412,6 +419,8 @@ Status BuildComputation( // Builds the XLA computation. builder->Tuple(elems); + builder->ClearOpMetadata(); + xla::StatusOr computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status(); @@ -514,6 +523,13 @@ Status XlaCompiler::BuildArguments( } } + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata arg_metadata; + arg_metadata.set_op_name("XLA_Args"); + builder->SetOpMetadata(arg_metadata); + // Build parameter handles for non-constant arguments. std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { @@ -552,6 +568,8 @@ Status XlaCompiler::BuildArguments( } } + builder->ClearOpMetadata(); + // Fill in the handles in non-constant arguments. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { @@ -582,6 +600,48 @@ Status XlaCompiler::BuildArguments( return Status::OK(); } +Status XlaCompiler::CompileSingleOp( + const XlaCompiler::CompileOptions& options, string const& name, + OpKernelContext* ctx, const std::vector& args, + CompilationResult* result) { + // TODO(b/74182462): We implement this by creating a new dummy Graph including + // _Arg nodes, and let CompileGraph walk it. This could be optimized. + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Status status; + // First create the actual node we care about computing. + Node* main_node = graph->AddNode(ctx->op_kernel().def(), &status); + TF_RETURN_IF_ERROR(status); + + // Create dummy _Arg nodes. Link these to `node` and also via a control + // dependency edge to the _SOURCE node. + for (int64 i = 0; i < ctx->num_inputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + Status status = NodeBuilder(name, "_Arg") + .ControlInput(graph->source_node()) + .Attr("T", ctx->input_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + graph->AddEdge(node, 0, main_node, i); + } + + // Similarly with return values, create dummy _Retval nodes fed by `node`. + for (int64 i = 0; i < ctx->num_outputs(); ++i) { + Node* node; + string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + Status status = NodeBuilder(name, "_Retval") + .Input(main_node, i) + .Attr("T", ctx->expected_output_dtype(i)) + .Attr("index", i) + .Finalize(graph.get(), &node); + TF_RETURN_IF_ERROR(status); + } + + return CompileGraph(options, name, std::move(graph), args, result); +} + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -656,6 +716,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); + // Copy the host transfer metadata to the result. + for (const auto& send : host_compute_sends_) { + *result->host_compute_metadata.add_device_to_host() = send.second; + } + for (const auto& recv : host_compute_recvs_) { + *result->host_compute_metadata.add_host_to_device() = recv.second; + } + // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); @@ -690,4 +758,59 @@ Status XlaCompiler::GetChannelHandle(const string& key, return Status::OK(); } +namespace { + +void SetTransfer(const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes, + tf2xla::HostTransferMetadata* transfer) { + transfer->set_key(key); + CHECK(types.size() == shapes.size()); + for (int i = 0; i < types.size(); ++i) { + tf2xla::TensorMetadata* metadata = transfer->add_metadata(); + metadata->set_type(types[i]); + shapes[i].AsProto(metadata->mutable_shape()); + } +} + +} // namespace + +Status XlaCompiler::SetDeviceToHostMetadata( + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { + if (host_compute_sends_.find(key) != host_compute_sends_.end()) { + return errors::InvalidArgument( + "Duplicate calls to SetDeviceToHostMetadata with key ", key); + } + tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; + SetTransfer(key, types, shapes, &transfer); + return Status::OK(); +} + +Status XlaCompiler::GetDeviceToHostShapes( + const string& key, std::vector* shapes) const { + const auto iter = host_compute_sends_.find(key); + if (iter == host_compute_sends_.end()) { + return errors::InvalidArgument( + "No host compute send shapes registered for key ", key); + } + shapes->clear(); + for (int i = 0; i < iter->second.metadata_size(); ++i) { + TensorShape shape(iter->second.metadata(i).shape()); + shapes->push_back(shape); + } + return Status::OK(); +} + +Status XlaCompiler::SetHostToDeviceMetadata( + const string& key, gtl::ArraySlice types, + gtl::ArraySlice shapes) { + if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { + return errors::InvalidArgument( + "Duplicate calls to SetHostToDeviceMetadata with key ", key); + } + tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; + SetTransfer(key, types, shapes, &transfer); + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c4449bc4be06daff856eff70c6d89be6ddbcf0ee..a6747bbe72e161b2ece55697825cce0e71145a5c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" @@ -216,6 +217,10 @@ class XlaCompiler { // containing both constant and non-constant results. std::vector outputs; + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla::HostComputeMetadata host_compute_metadata; + // Resources whose values were updated by the computation, ordered // by return value position. Resource updates follow the non-constant // results in the outputs of XLA computation. @@ -284,6 +289,14 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + // Compiles a single Op, given by an OpKernelContext, into an + // xla::Computation. Similar to CompileFunction but takes a single Op as + // input. + Status CompileSingleOp(const CompileOptions& options, string const& name, + OpKernelContext* ctx, + const std::vector& args, + CompilationResult* result); + // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. @@ -296,6 +309,22 @@ class XlaCompiler { // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + // Sets the shapes and types for the device to host transfer associated with + // 'key'. + Status SetDeviceToHostMetadata(const string& key, + gtl::ArraySlice types, + gtl::ArraySlice shapes); + + // Gets the shapes the device to host transfer associated with 'key'. + Status GetDeviceToHostShapes(const string& key, + std::vector* shapes) const; + + // Sets the shapes and types for the host to device transfer associated with + // 'key'. + Status SetHostToDeviceMetadata(const string& key, + gtl::ArraySlice types, + gtl::ArraySlice shapes); + const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } @@ -359,6 +388,9 @@ class XlaCompiler { std::unordered_map channels_; + std::unordered_map host_compute_sends_; + std::unordered_map host_compute_recvs_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index a18eeacd41808884fac9ec5d617cb0d274ea27d8..096dc7160bfc0a3a751f33e7d646471ebea56070 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -257,10 +258,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - StringPiece(status.error_message()).contains("depends on a parameter")) + str_util::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - StringPiece(status.error_message()).contains("[[Node: C = Reshape")) + str_util::StrContains(status.error_message(), "[[Node: C = Reshape")) << status.error_message(); } @@ -597,7 +598,8 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "is not defined.")) << status.error_message(); } @@ -676,11 +678,12 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "is not defined.")) << status.error_message(); // Local flib lookup failure. - EXPECT_TRUE( - StringPiece(status.error_message()).contains("Attr T is not found")) + EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), + "Attr T is not found")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index 8286480e0ea07429adbe31ec4f16d043e321df0a..ead229aaccc292d4944db0c1eaf98c82583533cd 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -30,6 +31,12 @@ bool CpuOpFilter(KernelDef* kdef) { DT_FLOAT); return true; } + if (kdef->op() == "Const") { + AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); + } + if (kdef->op() == "Assert") { + AddDtypeToKernalDefConstraint("T", DT_STRING, kdef); + } return true; } diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index 8ca757e72355d890c13b8b448d35c327d3986696..62168b648331844bfe2db1a4d5dcad895c8726f3 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" @@ -25,6 +26,12 @@ bool GpuOpFilter(KernelDef* kdef) { kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; } + if (kdef->op() == "Const") { + AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); + } + if (kdef->op() == "Assert") { + AddDtypeToKernalDefConstraint("T", DT_STRING, kdef); + } return true; } diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f048662953e20b2a612271e2daeef6e370c4822a..62a5114837e07f35134ad99e28880d6a9233a213 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -121,6 +122,9 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_HALF: + return b->ConstantR0( + static_cast(Eigen::NumTraits::epsilon())); case DT_BFLOAT16: return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: @@ -273,4 +277,20 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, return Status::OK(); } +DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { + if (dtype == DT_BFLOAT16) { + return DT_FLOAT; + } + return dtype; +} + +xla::ComputationDataHandle XlaHelpers::ConvertElementType( + xla::ComputationBuilder* const builder, + const xla::ComputationDataHandle& operand, + const DataType new_element_type) { + xla::PrimitiveType convert_to; + TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); + return builder->ConvertElementType(operand, convert_to); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 2a027db4c839c917f3a7acd27184792d157356bf..68ab93b64a5fa87ad99e0f44d84f6473fc8bbebd 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -107,6 +107,18 @@ class XlaHelpers { const xla::ComputationDataHandle& on_value, const xla::ComputationDataHandle& off_value, xla::ComputationDataHandle* one_hot); + + // Certain DataTypes should use increased precision DataTypes when performing + // reductions. This function remaps a given DataType to a higher precision + // DataType if needed. + static DataType SumAccumulationType(const DataType& dtype); + + // A helper for creating a ConvertElementType xla op given a DataType rather + // than the xla::PrimitiveType. + static xla::ComputationDataHandle ConvertElementType( + xla::ComputationBuilder* const builder, + const xla::ComputationDataHandle& operand, + const DataType new_element_type); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 0dde6a986c61bdd5b0b2e6d7a16b29ab95be98ab..bbe808595d958346bd55bf8419306bf3de4cd1d0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -255,6 +255,8 @@ void XlaOpRegistry::RegisterCompilationKernels() { std::vector XlaOpRegistry::DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels) { + // Ensure compilation kernels registered. + RegisterCompilationKernels(); std::vector kernels; XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index ff7453194af3a85bded86a5ce298f8779422dccb..e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -51,13 +51,13 @@ constexpr std::array kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, +constexpr std::array kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 34e733bc8d80b364cec1783006eba0a5468b55ea..751777222fcc7ec073958349aa2677d5b4e6757d 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -52,6 +52,7 @@ xla_proto_library( visibility = ["//visibility:public"], deps = [ ":xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:session_proto", ], ) @@ -372,7 +373,6 @@ tf_cc_test( cc_library( name = "array2d", - srcs = ["array2d.cc"], hdrs = ["array2d.h"], visibility = ["//visibility:public"], deps = [ @@ -654,18 +654,6 @@ tf_cc_test( # ----------------------------------------------------------------------------- -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. cc_header_only_library( name = "xla_headers_lib", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 46ee4e64c9ae7ca111d9d04bedcb74ff02a42386..ea75ad32d5df7bbadd37e89de6144b264ab6d5d1 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -121,10 +122,31 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 2D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 1D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && + std::is_same::value>::type> + Array(std::initializer_list values) + : Array(ToInt64Vector({values.size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + values_[idx] = static_cast(it1); + ++idx; + } + CHECK(idx == num_elements()); + } + + // Creates a 2D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. + template ::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list> values) : Array(ToInt64Vector({values.size(), values.begin()->size()})) { @@ -155,10 +177,13 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 3D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 3D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list>> values) @@ -196,10 +221,13 @@ class Array { CHECK(idx == num_elements()); } - // Creates a 4D array of Eigen::half from the given nested initializer list of - // float values. + // Creates a 4D array of a floating-point type (half, bfloat16, float, + // or double) from an initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array(std::initializer_list< std::initializer_list>>> diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 41f563486d21e42e88dcf6c751ce4a64da5e3213..a17e81f44832f272fd93dce9f854042b4a84fde4 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -52,10 +53,13 @@ class Array2D : public Array { Array2D(std::initializer_list> values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array2D(std::initializer_list> values) : Array(values) {} @@ -94,9 +98,23 @@ class Array2D : public Array { // Returns a linspace-populated Array2D in the range [from, to] (inclusive) // with dimensions n1 x n2. -std::unique_ptr> MakeLinspaceArray2D(float from, float to, - int64 n1, int64 n2); - +template +std::unique_ptr> MakeLinspaceArray2D(double from, double to, + int64 n1, int64 n2) { + auto array = MakeUnique>(n1, n2); + int64 count = n1 * n2; + NativeT step = + static_cast((count > 1) ? (to - from) / (count - 1) : 0); + auto set = [&array, n1, n2](int64 index, NativeT value) { + (*array)(index / n2, index % n2) = value; + }; + for (int64 i = 0; i < count - 1; ++i) { + set(i, (static_cast(from) + + static_cast(i) * static_cast(step))); + } + set(count - 1, static_cast(to)); + return array; +} } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_ARRAY2D_H_ diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index e5eb235d45d160d486d1499db665ed14a8509043..0e9a0722ae43e1dc6ecddde9cbc3daf1db058840 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -57,10 +57,13 @@ class Array3D : public Array { values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array3D( std::initializer_list>> diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index cff70e54bad0116bdd08674b626b3bf99dc89e1f..a75fffc605aa0df3e1e2eeb6d3129718cbbba0e4 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -82,10 +82,13 @@ class Array4D : public Array { values) : Array(values) {} - // Creates an array of Eigen::half from the given nested initializer list of - // float values. + // Creates an array of a floating-point type (half, bfloat16, float, + // or double) from the given nested initializer list of float values. template ::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) && std::is_same::value>::type> Array4D(std::initializer_list>>> diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 02356699a25e47be50eb15872df4c9c302fc289b..a299c2afd45aa6b785964b8a8e1400ddf54083a4 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -74,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", @@ -213,17 +214,3 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d15ccb0c28522c647617153aaa8e738d029dfaba..f0f94298a05f7c4bdc41cbfb8572454fbedd371d 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -177,6 +177,50 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } +StatusOr> Client::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr data, + Execute(computation, arguments, execution_options, execution_profile)); + + const Shape* shape_with_output_layout = nullptr; + if (execution_options && execution_options->has_shape_with_output_layout()) { + shape_with_output_layout = &execution_options->shape_with_output_layout(); + } + return Transfer(*data, shape_with_output_layout); +} + +StatusOr> Client::ComputeConstant( + const XlaComputation& computation, const Layout* output_layout) const { + ComputeConstantGraphRequest request; + *request.mutable_computation() = computation.proto(); + if (output_layout != nullptr) { + *request.mutable_output_layout() = *output_layout; + } + + ComputeConstantResponse response; + + VLOG(2) << "making compute-constant-graph request"; + Status s = stub_->ComputeConstantGraph(&request, &response); + VLOG(2) << "done with request"; + + if (!s.ok()) { + return s; + } + + VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; + + if (!response.has_literal()) { + return InternalError( + "no computed literal in the provided response in ComputeConstantGraph " + "request"); + } + return Literal::CreateFromProto(response.literal()); +} + StatusOr Client::LoadSnapshot(const SessionModule& module) { LoadComputationSnapshotRequest request; *request.mutable_module() = module; @@ -231,6 +275,46 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } +StatusOr> Client::Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options, + ExecutionProfile* execution_profile) { + ExecuteGraphRequest request; + *request.mutable_computation() = computation.proto(); + + if (execution_options == nullptr) { + *request.mutable_execution_options() = CreateDefaultExecutionOptions(); + } else { + *request.mutable_execution_options() = *execution_options; + } + for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; + *request.add_arguments() = argument->handle(); + } + + ExecuteResponse response; + VLOG(1) << "making execute request: " << request.ShortDebugString(); + Status s = stub_->ExecuteGraph(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + if (execution_profile != nullptr) { + *execution_profile = response.profile(); + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN( + auto execution_stats, + ExecutionStatsAsString(computation, response.profile())); + VLOG(1) << execution_stats; + } + } + + return MakeUnique(stub_, response.output()); +} + StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteParallelRequest request; @@ -266,6 +350,42 @@ StatusOr>> Client::ExecuteParallel( return std::move(outputs); } +StatusOr>> Client::ExecuteParallel( + tensorflow::gtl::ArraySlice computations) { + ExecuteGraphParallelRequest request; + + for (const XlaComputationInstance& computation : computations) { + ExecuteGraphRequest single_request; + *single_request.mutable_computation() = computation.computation.proto(); + for (GlobalData* argument : computation.arguments) { + *single_request.add_arguments() = argument->handle(); + } + *single_request.mutable_execution_options() = computation.execution_options; + *request.add_requests() = single_request; + } + + ExecuteParallelResponse response; + VLOG(1) << "making execute-graph-parallel request: " + << request.ShortDebugString(); + tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + + std::vector> outputs; + for (size_t i = 0; i < computations.size(); ++i) { + outputs.push_back( + MakeUnique(stub_, response.responses(i).output())); + if (computations[i].execution_profile != nullptr) { + *computations[i].execution_profile = response.responses(i).profile(); + } + } + + return std::move(outputs); +} + StatusOr> Client::GetDeviceHandles( int64 device_count) { if (device_count < 1) { @@ -342,6 +462,27 @@ StatusOr Client::GetComputationStats( return response.stats(); } +StatusOr Client::GetComputationStats( + const XlaComputation& computation, + const DebugOptions& debug_options) const { + ComputationGraphStatsRequest request; + + // TODO(b/74197823): Find a way to avoid the copy of the hlo proto. + *request.mutable_computation() = computation.proto(); + *request.mutable_debug_options() = debug_options; + ComputationStatsResponse response; + + VLOG(1) << "making computation graph stats request"; + Status s = stub_->GetComputationGraphStats(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + CHECK(response.has_stats()); + return response.stats(); +} + StatusOr> Client::GetComputationShape( const Computation& computation) { GetComputationShapeRequest request; @@ -359,6 +500,12 @@ StatusOr> Client::GetComputationShape( return WrapUnique(response.release_program_shape()); } +StatusOr> Client::GetComputationShape( + const XlaComputation& computation) { + TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); + return MakeUnique(result); +} + StatusOr Client::GetShape(const GlobalData& data) { GetShapeRequest request; *request.mutable_data() = data.handle(); @@ -397,6 +544,28 @@ StatusOr Client::ExecutionStatsAsString( return string("[Execution Statistics] not available."); } +StatusOr Client::ExecutionStatsAsString( + const XlaComputation& computation, const ExecutionProfile& profile) { + TF_ASSIGN_OR_RETURN( + auto computation_stats, + GetComputationStats(computation, + legacy_flags::GetDebugOptionsFromFlags())); + int64 total_flops = + computation_stats.flop_count() + computation_stats.transcendental_count(); + if (profile.compute_time_ns() > 0) { + int64 nanoseconds = profile.compute_time_ns(); + int64 cycle_count = profile.compute_cycle_count(); + double gflops = total_flops / nanoseconds; + return tensorflow::strings::StrCat( + "[Execution Statistics] flop count: ", computation_stats.flop_count(), + ", transcendental count: ", computation_stats.transcendental_count(), + ", compute execution time: ", nanoseconds, " nsec", + ", compute cycles: ", cycle_count, ", performance: ", gflops, + "gflop/s"); + } + return string("[Execution Statistics] not available."); +} + StatusOr Client::CreateChannelHandle() { CreateChannelHandleRequest request; CreateChannelHandleResponse response; diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index c28380b689c7a0e16bf0bcbf15003f4aa15e42a7..14c685d94ea31c382d84223ca4e2eba544420d78 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service_interface.h" @@ -57,6 +58,21 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and returns the global + // data that was produced from the execution. + // * If execution_options is not nullptr, these options are passed to the + // service to affect how it compiles our computation. (The pointer does not + // need to live beyond this call.) + // * If execution_profile is not nullptr then the pointed-to ExecutionProfile + // will be filled with profile data from the execution. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> Execute( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + // A struct to represent a computation instance to be executed. // * If execution_options.device_handles is not empty, the computation is // executed on the devices associated with the handles by partitioning the @@ -83,6 +99,36 @@ class Client { StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); + // A struct to represent a computation instance to be executed. + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + struct XlaComputationInstance { + const XlaComputation& computation; + std::vector arguments; + ExecutionOptions execution_options; + ExecutionProfile* execution_profile; + + XlaComputationInstance(const XlaComputation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} + }; + + // Executes a list XlaComputationInstances and returns global data produced + // from each computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr>> ExecuteParallel( + tensorflow::gtl::ArraySlice computations); + // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations // (see ExecuteParallel) or to transfer data (see TransferToServer or @@ -137,6 +183,38 @@ class Client { const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); + // Executes the computation with the given arguments and transfers the result + // to the client as a literal. Parameters are defined the same as for + // Execute() and Transfer(). + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions* execution_options = nullptr, + ExecutionProfile* execution_profile = nullptr); + + // Computes the value of the given computation using a non-optimized + // interpreter on the host. + // + // The computation must not depend on any parameters, or on stateful operators + // such as `RngNormal` or `Infeed`. + // + // This functionality can be useful when translating a computation into XLA + // where something that looked dynamic is required by XLA to be specified as a + // constant. E.g. the source computation (outside of XLA) may include a + // dynamic computation of the shape of something and ComputeConstant lets you + // determine what the value of that computation is in the case where the value + // can be determined at compile time. + // + // If output_layout is non-null, then the output of the computation will be + // stored using that layout. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> ComputeConstant( + const XlaComputation& computation, + const Layout* output_layout = nullptr) const; + // Unregister the memory for the given GlobalData on the device. Status Unregister(const GlobalData& data); @@ -148,6 +226,13 @@ class Client { StatusOr GetComputationStats( const Computation& computation, const DebugOptions& debug_options) const; + // Retrieves the statistics of the given computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr GetComputationStats( + const XlaComputation& computation, + const DebugOptions& debug_options) const; + // Returns the Shape of the given array specified by 'data'. The shape // includes the Layout of the array as it is stored on the service. StatusOr GetShape(const GlobalData& data); @@ -157,6 +242,13 @@ class Client { StatusOr> GetComputationShape( const Computation& computation); + // As above, but returns the shape of the provided computation (parameter + // types/names and return type). + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> GetComputationShape( + const XlaComputation& computation); + // Creates a channel handle that can be used to transfer data between // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); @@ -170,6 +262,8 @@ class Client { // ExecutionProfile returned from an execution of the computation. StatusOr ExecutionStatsAsString(const Computation& computation, const ExecutionProfile& profile); + StatusOr ExecutionStatsAsString(const XlaComputation& computation, + const ExecutionProfile& profile); ServiceInterface* stub_; // Stub that this client is connected on. diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index c7e2c4367b89ca2112022fa40449ae3ebe28463e..59662c95ac15e7c23790c5b5ff5d75a694613aeb 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -39,16 +39,15 @@ CompileOnlyClient::CompileAheadOfTime( return compiler_service_->CompileAheadOfTime(service_instances, options); } -int64 CompileOnlyClient::PointerSizeForTriple( - tensorflow::StringPiece target_triple) { - llvm::Triple triple(llvm::Triple::normalize( - llvm::StringRef(target_triple.data(), target_triple.size()))); - if (triple.isArch64Bit()) { +int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { + llvm::Triple llvm_triple( + llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); + if (llvm_triple.isArch64Bit()) { return 8; - } else if (triple.isArch32Bit()) { + } else if (llvm_triple.isArch32Bit()) { return 4; } else { - CHECK(triple.isArch16Bit()); + CHECK(llvm_triple.isArch16Bit()); return 2; } } diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index b1dcad6a49a270935b07e26de2d3945b912359d1..4d3b0ee0d6e9ba82cfa09af0fbff0ae1efa0ac64 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -253,26 +253,6 @@ StatusOr ComputationBuilder::GetProgramShape() { return std::move(*response.mutable_program_shape()); } -ComputationDataHandle ComputationBuilder::CheckShape( - const ComputationDataHandle& operand, const Shape& expected_shape) { - std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); - CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) - << "want " << ShapeUtil::HumanString(expected_shape) << " got " - << ShapeUtil::HumanString(*actual_shape); - return operand; -} - -void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { - std::unique_ptr lhs_shape = GetShape(lhs).ConsumeValueOrDie(); - std::unique_ptr rhs_shape = GetShape(rhs).ConsumeValueOrDie(); - VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals " - << ShapeUtil::HumanString(*rhs_shape); - CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape)) - << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs " - << ShapeUtil::HumanString(*rhs_shape); -} - ComputationDataHandle ComputationBuilder::Slice( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice start_indices, @@ -408,7 +388,7 @@ ComputationDataHandle ComputationBuilder::Reshape( ComputationDataHandle ComputationBuilder::Collapse( const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dims_to_collapse) { + tensorflow::gtl::ArraySlice dimensions) { if (!first_error_.ok()) { return ComputationDataHandle(); } @@ -416,8 +396,8 @@ ComputationDataHandle ComputationBuilder::Collapse( // Don't support out-of-order collapse here. // Checks that the collapsed dimensions are in order and consecutive. for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dims_to_collapse.size(); ++i) { - if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { + i < dimensions.size(); ++i) { + if (dimensions[i] - 1 != dimensions[i - 1]) { NoteError(InvalidArgument( "Collapsed dimensions are not in order and consecutive.")); return ComputationDataHandle(); @@ -434,9 +414,9 @@ ComputationDataHandle ComputationBuilder::Collapse( VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dims_to_collapse, ","); + << tensorflow::str_util::Join(dimensions, ","); - if (dims_to_collapse.size() <= 1) { + if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus // enqueueing a trivial reshape. return operand; @@ -444,7 +424,7 @@ ComputationDataHandle ComputationBuilder::Collapse( std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { + if (i <= dimensions.front() || i > dimensions.back()) { new_sizes.push_back(original_shape->dimensions(i)); } else { new_sizes.back() *= original_shape->dimensions(i); @@ -753,13 +733,13 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, } void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape, + const Shape& shape_with_layout, const string& outfeed_config) { OpRequest op_request; OutfeedRequest* request = op_request.mutable_outfeed_request(); request->set_outfeed_config(outfeed_config); *request->mutable_operand() = operand; - *request->mutable_shape() = shape; + *request->mutable_shape() = shape_with_layout; RunOpAndNoteError(&op_request); } @@ -789,6 +769,20 @@ ComputationDataHandle ComputationBuilder::CustomCall( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::HostCompute( + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { + OpRequest op_request; + HostComputeRequest* request = op_request.mutable_host_compute_request(); + for (const ComputationDataHandle& operand : operands) { + *request->add_operands() = operand; + } + *request->mutable_shape() = shape; + request->set_channel_name(channel_name); + request->set_cost_estimate_ns(cost_estimate_ns); + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Complex( const ComputationDataHandle& real, const ComputationDataHandle& imag, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -854,6 +848,14 @@ ComputationDataHandle ComputationBuilder::Or( return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); } +// TODO(b/65209188): Create a dedicated lowering for Xor +ComputationDataHandle ComputationBuilder::Xor( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return Or(And(Not(lhs), rhs, broadcast_dimensions), + And(lhs, Not(rhs), broadcast_dimensions)); +} + ComputationDataHandle ComputationBuilder::Not( const ComputationDataHandle& operand) { return UnaryOp(UNOP_NOT, operand); @@ -1220,6 +1222,22 @@ ComputationDataHandle ComputationBuilder::While( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::Gather( + const ComputationDataHandle& input, + const ComputationDataHandle& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + OpRequest op_request; + GatherRequest* gather_request = op_request.mutable_gather_request(); + *gather_request->mutable_input() = input; + *gather_request->mutable_gather_indices() = gather_indices; + *gather_request->mutable_dimension_numbers() = dimension_numbers; + for (int64 window_bound : window_bounds) { + gather_request->add_window_bounds(window_bound); + } + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Conditional( const ComputationDataHandle& predicate, const ComputationDataHandle& true_operand, @@ -1352,15 +1370,16 @@ ComputationDataHandle ComputationBuilder::BatchNormInference( ComputationDataHandle ComputationBuilder::BatchNormGrad( const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& mean, const ComputationDataHandle& var, + const ComputationDataHandle& batch_mean, + const ComputationDataHandle& batch_var, const ComputationDataHandle& grad_output, float epsilon, int64 feature_index) { OpRequest op_request; BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); *request->mutable_operand() = operand; *request->mutable_scale() = scale; - *request->mutable_mean() = mean; - *request->mutable_variance() = var; + *request->mutable_mean() = batch_mean; + *request->mutable_variance() = batch_var; *request->mutable_grad_output() = grad_output; request->set_epsilon(epsilon); request->set_feature_index(feature_index); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 7cae91e9e04bba8f28f2348c552a941e4f7a36b4..019c6f3afb5d57bfe453988ded19120a4483cf36 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -104,15 +104,6 @@ class ComputationBuilder { // Retrieves the (inferred) result for the current computation's shape. StatusOr GetProgramShape(); - // Checks that the operand has the given expected shape. Returns the operand - // if yes, fails with a CHECK error if no. - ComputationDataHandle CheckShape(const ComputationDataHandle& operand, - const Shape& expected_shape); - - // Checks that the lhs and rhs results have the same shape. - void CheckSameShape(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - // Enqueues a constant with the value of the given literal onto the // computation. ComputationDataHandle ConstantLiteral(const Literal& literal); @@ -198,9 +189,8 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice new_sizes); // Enqueues an operation onto the computation that collapses the operand, from - // minor to major order, then reshapes it into the shape with the given - // dimension sizes, also from major to minor. Conceptually, this is a limited - // form of "shape casting". + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". ComputationDataHandle Reshape(const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice new_sizes); @@ -446,6 +436,16 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice operands, const Shape& shape); + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + ComputationDataHandle HostCompute( + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape); + // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one // of the operands is a scalar, or an explicit broadcast dimension is given @@ -503,6 +503,10 @@ class ComputationBuilder { const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle Xor( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle Not(const ComputationDataHandle& operand); ComputationDataHandle ShiftLeft( @@ -708,6 +712,13 @@ class ComputationBuilder { const int exponent_bits, const int mantissa_bits); + // Enqueues a Gather node onto the computation. + ComputationDataHandle Gather( + const ComputationDataHandle& input, + const ComputationDataHandle& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + // Enqueues a Send node onto the computation, to send the given operand to // a Recv instruction that shares the same channel handle. void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); @@ -856,7 +867,7 @@ class ComputationBuilder { Window* window); // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation binop, + ComputationDataHandle UnaryOp(UnaryOperation unop, const ComputationDataHandle& operand); // Internal helper method that does the building for an arbitrary binary op. diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 804e34f5e75ce2d153ac7627b94a543fda88e810..6e3c5cb484b8f1ef053fa287a4d462aeb886e530 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -76,4 +76,35 @@ ExecutableBuildOptions::generate_hlo_graph() const { return generate_hlo_graph_; } +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_optimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { + return dump_optimized_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { + return dump_per_pass_hlo_proto_to_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { + hlo_profile_ = enabled; + return *this; +} + +tensorflow::gtl::optional ExecutableBuildOptions::hlo_profile() const { + return hlo_profile_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 3a52dbac9adb155ad9a7d91a8102707f70fe2fbf..11f10983606fe02b1edb11a260edde8e5f9a726f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -57,15 +58,36 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_generate_hlo_graph(string regex); const tensorflow::gtl::optional& generate_hlo_graph() const; + // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs + // to (as in DebugOptions). + ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_per_pass_hlo_proto_to() const; + + // If true, specifies that we should record an HLO profile during execution + // and log it after execution (as in DebugOptions). If nullopt the default is + // used. + ExecutableBuildOptions& set_hlo_profile(bool enabled); + tensorflow::gtl::optional hlo_profile() const; + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; private: + tensorflow::gtl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; + tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index fca2bf2688cd21b44f099da3bae3b890cbb069ab..59c4a53c05a45490a7c8e732840a4e70767c46c2 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -24,6 +24,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -44,21 +46,8 @@ cc_library( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 24048a1e5a782661ba577ba50e3b5b2914f17c0a..63df449e0b3bdd642d548319dd7d621ca2f59b1d 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -26,6 +26,7 @@ limitations under the License. namespace xla { namespace { + using InstructionGenerator = ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, const ComputationDataHandle&); @@ -47,6 +48,27 @@ Computation CreateScalarComputation(const string& name, PrimitiveType type, generator(b.get(), lhs, rhs); return b->BuildAndNoteError(); } + +using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); + +XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, + XlaBuilder* builder, + XlaOpGenerator generator) { + std::unique_ptr b; + if (type == PRED) { + b = builder->CreateSubBuilder(name); + } else { + b = builder->CreateSubBuilder( + tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + } + + const Shape scalar = ShapeUtil::MakeShape(type, {}); + auto lhs = b->Parameter(0, scalar, "lhs"); + auto rhs = b->Parameter(1, scalar, "rhs"); + generator(b.get(), lhs, rhs); + return b->BuildAndNoteError(); +} + } // namespace Computation CreateScalarAddComputation(PrimitiveType type, @@ -60,7 +82,7 @@ Computation CreateScalarAddComputation(PrimitiveType type, Computation CreateScalarMultiplyComputation(PrimitiveType type, ComputationBuilder* builder) { return CreateScalarComputation( - "add", type, builder, + "mul", type, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); } @@ -114,4 +136,75 @@ StatusOr Any(const ComputationDataHandle& predicates, return builder->Reduce(predicates, f, logical_or, all_dimensions); } +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "add", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Add(lhs, rhs); + }); +} + +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "mul", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Mul(lhs, rhs); + }); +} + +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "ge", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Ge(lhs, rhs); + }); +} + +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "max", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Max(lhs, rhs); + }); +} + +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder) { + return CreateScalarComputation( + "min", type, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Min(lhs, rhs); + }); +} + +XlaComputation CreateScalarAndComputation(XlaBuilder* builder) { + return CreateScalarComputation( + "and", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->And(lhs, rhs); + }); +} + +XlaComputation CreateScalarOrComputation(XlaBuilder* builder) { + return CreateScalarComputation( + "or", PRED, builder, + [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) { + return b->Or(lhs, rhs); + }); +} + +StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { + auto f = builder->ConstantR0(false); + XlaComputation logical_or = CreateScalarOrComputation(builder); + TF_ASSIGN_OR_RETURN(const Shape& predicates_shape, + builder->GetShape(predicates)); + std::vector all_dimensions(ShapeUtil::Rank(predicates_shape)); + std::iota(all_dimensions.begin(), all_dimensions.end(), 0); + return builder->Reduce(predicates, f, logical_or, all_dimensions); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index ae89784bc227d837cf15f0a89687dd00dccc2745..f4d3fc801590fedbb84ed3d6283e62f47c56d5c7 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -56,6 +58,48 @@ Computation CreateScalarOrComputation(ComputationBuilder* builder); StatusOr Any(const ComputationDataHandle& predicates, ComputationBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar add computation and returns it. +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar multiply computation and returns it. +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar ge computation and returns it. +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar max computation and returns it. +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar min computation and returns it. +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder); +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar logical AND computation and returns it. +XlaComputation CreateScalarAndComputation(XlaBuilder* builder); + +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Creates a scalar logical OR computation and returns it. +XlaComputation CreateScalarOrComputation(XlaBuilder* builder); + +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +// +// Returns whether any predicate in "predicates" is set. +// +// Note: if predicates is zero-sized, Any() vacuously returns false. +StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index b63a1465ea755b906853860d47768ecbeaa0dcdd..311dc4bdd72cfd7999e83a26e11614d6ca005bce 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -111,4 +111,20 @@ std::vector> MakeFakeArgumentsOrDie( return fake_arguments; } +std::vector> MakeFakeArgumentsOrDie( + const XlaComputation& computation, Client* client) { + CHECK(computation.proto().has_program_shape()) + << "Computation should have progran shape."; + auto program_shape = computation.proto().program_shape(); + + // For every (unbound) parameter that the computation wants, we manufacture + // some arbitrary data so that we can invoke the computation. + std::vector> fake_arguments; + for (const Shape& parameter : program_shape.parameters()) { + fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); + } + + return fake_arguments; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 7e640d1307edcc3e2c021f4391c456f578a015ee..1dc2622972d5fd3da6991d70b800cc3fd5a638f4 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -38,6 +39,12 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, std::vector> MakeFakeArgumentsOrDie( const Computation& computation, Client* client); +// Returns vector of GlobalData handles of fake data (created using +// MakeFakeDataOrDie) that are correctly shaped arguments for the given +// xla computation. +std::vector> MakeFakeArgumentsOrDie( + const XlaComputation& computation, Client* client); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 91396f055fe4a3ecbd436139be9470e2a35e1c63..30594243dcf51d2b5312b9dcb2bea7d0cd78524d 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -265,6 +265,24 @@ StatusOr> LocalClient::Compile( updated_options)); } +StatusOr> LocalClient::Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options) { + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + local_service_->CompileExecutable( + computation, argument_layouts, updated_options)); + return WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); +} + StatusOr> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index b52a30f5a0b92e0094e6b0de3241c10a5a909cad..98ee7c62c94be7c618cedd3dc12ecbfc812ee180 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -69,7 +69,7 @@ class LocalExecutable { // of the computation. tensorflow::Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, const Backend& backend); + const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. @@ -123,6 +123,15 @@ class LocalClient : public Client { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); + // Build and return a LocalExecutable object. The executable is compiled using + // the given XlaComputation, argument layouts and options. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> Compile( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& options); + // Copy the literal data to the device with the given ordinal and return as a // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..31fa1241ee474a31575c45cf7652063dfc818fac --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -0,0 +1,79 @@ +# Description: +# The new XLA client libraries. +# +# This is NOT YET ready to use. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = [":friends"]) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:lib", + ], +) + +# TODO(b/74197823): Replace computation_builder with xla_builder. +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + deps = [ + ":xla_computation", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:shape_inference", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":xla_builder", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ccdc2ded2c099690bc9187936db6491ef4142dd --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -0,0 +1,1963 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { + +using tensorflow::strings::StrCat; + +namespace { + +int64 GetUniqueId() { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static int64 built_counter = 0; + tensorflow::mutex_lock loc(mu); + const int64 id = built_counter++; + return id; +} + +// Returns true if an instruction with the given opcode can be the root of the +// computation. +bool CanBeRoot(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kOutfeed: + case HloOpcode::kTrace: + return false; + default: + return true; + } +} + +StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); + operand_shapes.push_back(shape); + } + return operand_shapes; +} + +} // namespace + +StatusOr XlaBuilder::GetShape(const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + + TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op)); + return instr->shape(); +} + +StatusOr XlaOp::GetShape() const { + if (builder_ == nullptr) { + return InvalidArgument( + "cannot GetShape for an invalid XlaOp with handle %lld", handle()); + } + return builder_->GetShape(*this); +} + +XlaBuilder::XlaBuilder(const string& computation_name) + : name_(computation_name) {} + +XlaBuilder::~XlaBuilder() {} + +void XlaBuilder::NoteError(const Status& error) { + CHECK(!error.ok()); + if (die_immediately_on_error_) { + LOG(FATAL) << "error building computation: " << error; + } + + if (first_error_.ok()) { + first_error_ = error; + first_error_backtrace_.CreateCurrent(/*skip_count=*/1); + } +} + +XlaOp XlaBuilder::NoteErrorOrReturn( + const std::function()>& op_creator) { + if (!first_error_.ok()) { + return {}; + } + auto op = op_creator(); + if (!op.ok()) { + NoteError(op.status()); + return {}; + } + return op.ConsumeValueOrDie(); +} + +StatusOr XlaBuilder::GetProgramShape(int64* root_id) const { + TF_RETURN_IF_ERROR(first_error_); + + TF_RET_CHECK(root_id != nullptr); + + ProgramShape program_shape; + + // Not all instructions can be roots. Walk backwards from the last added + // instruction until a valid root is found. + int64 index = instructions_.size() - 1; + for (; index >= 0; index--) { + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instructions_[index].opcode())); + if (CanBeRoot(opcode)) { + break; + } + } + if (index < 0) { + return FailedPrecondition("no root instruction was found"); + } + *root_id = instructions_[index].id(); + *program_shape.mutable_result() = instructions_[index].shape(); + + // Check that the parameter numbers are continuous from 0, and add parameter + // shapes and names to the program shape. + const int64 param_count = parameter_numbers_.size(); + for (int64 i = 0; i < param_count; i++) { + program_shape.add_parameters(); + program_shape.add_parameter_names(); + } + for (const HloInstructionProto& instr : instructions_) { + // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So + // to verify continuity, we just need to verify that every parameter is in + // the right range. + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + const int64 index = instr.parameter_number(); + TF_RET_CHECK(index >= 0 && index < param_count) + << "invalid parameter number: " << index; + *program_shape.mutable_parameters(index) = instr.shape(); + *program_shape.mutable_parameter_names(index) = instr.name(); + } + } + return program_shape; +} + +StatusOr XlaBuilder::GetProgramShape() const { + int64 root; + return GetProgramShape(&root); +} + +void XlaBuilder::IsConstantVisitor(const int64 op_handle, + std::set* visited, + bool* is_constant) const { + if (visited->count(op_handle) != 0 || !*is_constant) { + return; + } + + CHECK(op_handle < instructions_.size() && op_handle >= 0); + + const HloInstructionProto& instr = instructions_[op_handle]; + const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); + switch (opcode) { + default: + for (const int64 operand_id : instr.operand_ids()) { + IsConstantVisitor(operand_id, visited, is_constant); + } + // TODO(b/32495713): We aren't checking the called computations. + break; + + // Non functional ops. + case HloOpcode::kRng: + case HloOpcode::kCrossReplicaSum: + // TODO(b/33009255): Implmement constant folding for cross replica sum. + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kHostCompute: + case HloOpcode::kCall: + // TODO(b/32495713): We aren't checking the to_apply computation itself, + // so we conservatively say that computations containing the Call op + // cannot be constant. We cannot set is_functional=false in other similar + // cases since we're already relying on IsConstant to return true. + case HloOpcode::kCustomCall: + case HloOpcode::kWhile: + // TODO(b/32495713): We aren't checking the condition and body + // computations themselves. + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kParameter: + *is_constant = false; + break; + } + if (!*is_constant) { + VLOG(1) << "Non-constant: " << instr.name(); + } + visited->insert(op_handle); +} + +XlaComputation XlaBuilder::BuildAndNoteError() { + DCHECK(parent_builder_ != nullptr); + auto build_status = Build(); + if (!build_status.ok()) { + parent_builder_->NoteError( + AddStatus(build_status.status(), + tensorflow::strings::StrCat("error from: ", name_))); + return {}; + } + return build_status.ConsumeValueOrDie(); +} + +StatusOr XlaBuilder::Build() { + if (!first_error_.ok()) { + string backtrace; + first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); + return AppendStatus(first_error_, backtrace); + } + + HloComputationProto entry; + entry.set_id(GetUniqueId()); // Give the computation a global unique id. + entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique. + + { + int64 root_id; + TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), + GetProgramShape(&root_id)); + entry.set_root_id(root_id); + } + + for (auto& instruction : instructions_) { + // Ensures that the instruction names are unique among the whole graph. + const string& new_name = + StrCat(instruction.name(), ".", entry.id(), ".", instruction.id()); + instruction.set_name(new_name); + entry.add_instructions()->Swap(&instruction); + } + + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_program_shape() = entry.program_shape(); + for (auto& e : embedded_) { + module->add_computations()->Swap(&e.second); + } + module->add_computations()->Swap(&entry); + + // Clear data held by this builder. + this->instructions_.clear(); + this->embedded_.clear(); + this->parameter_numbers_.clear(); + + return std::move(computation); +} + +StatusOr XlaBuilder::InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + TF_RETURN_IF_ERROR(first_error_); + + HloInstructionProto instr; + *instr.mutable_shape() = shape; + for (int64 dim : broadcast_dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kBroadcast, {operand}); +} + +StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + + CHECK(ShapeUtil::IsScalar(operand_shape) || + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = + ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type()); + + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand_shape)) { + return InDimBroadcast(broadcast_shape, operand, {}); + } + + // Do explicit broadcast for degenerate broadcast. + std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand_shape); i++) { + if (operand_shape.dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand_shape.dimensions(i)); + } else { + TF_RET_CHECK(operand_shape.dimensions(i) == 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand shape: " + << operand_shape << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + Reshape(ShapeUtil::MakeShape(operand_shape.element_type(), + reshaped_dimensions), + operand)); + // Broadcast 'reshape' up to the larger size. + return InDimBroadcast(broadcast_shape, reshaped_operand, + broadcast_dimensions); +} + +XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferUnaryOpShape(unop, operand_shape)); + return AddInstruction(std::move(instr), unop, {operand}); + }); +} + +XlaOp XlaBuilder::BinaryOp( + HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBinaryOpShape( + binop, lhs_shape, rhs_shape, broadcast_dimensions)); + + const int64 lhs_rank = ShapeUtil::Rank(lhs_shape); + const int64 rhs_rank = ShapeUtil::Rank(rhs_shape); + + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + + if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) { + const bool should_broadcast_lhs = lhs_rank < rhs_rank; + XlaOp from = should_broadcast_lhs ? lhs : rhs; + const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape; + + std::vector to_size; + for (int64 size : instr.shape().dimensions()) { + to_size.push_back(size); + } + for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape); + from_dim++) { + int64 to_dim = broadcast_dimensions[from_dim]; + to_size[to_dim] = from_shape.dimensions(from_dim); + } + + const Shape& broadcasted_shape = + ShapeUtil::MakeShape(from_shape.element_type(), to_size); + TF_ASSIGN_OR_RETURN( + XlaOp broadcasted_operand, + InDimBroadcast(broadcasted_shape, from, broadcast_dimensions)); + + updated_lhs = should_broadcast_lhs ? broadcasted_operand : lhs; + updated_rhs = !should_broadcast_lhs ? broadcasted_operand : rhs; + } + + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), updated_lhs)); + } + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), updated_rhs)); + } + + return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs}); + }); +} + +XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferTernaryOpShape( + triop, lhs_shape, rhs_shape, ehs_shape)); + XlaOp updated_lhs = lhs; + XlaOp updated_rhs = rhs; + XlaOp updated_ehs = ehs; + if (!ShapeUtil::IsTuple(instr.shape())) { + if (!ShapeUtil::IsTuple(lhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) { + // lhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_lhs, + AddBroadcastSequence(instr.shape(), lhs)); + } + if (!ShapeUtil::IsTuple(rhs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) { + // rhs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_rhs, + AddBroadcastSequence(instr.shape(), rhs)); + } + if (!ShapeUtil::IsTuple(ehs_shape) && + !ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) { + // ehs is being implicitly broadcasted. Change to explicit. + TF_ASSIGN_OR_RETURN(updated_ehs, + AddBroadcastSequence(instr.shape(), ehs)); + } + } + return AddInstruction(std::move(instr), triop, + {updated_lhs, updated_rhs, updated_ehs}); + }); +} + +XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = literal.shape(); + *instr.mutable_literal() = literal.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kConstant); + }); +} + +XlaOp XlaBuilder::Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCallShape(operand_shape_ptrs, + /*to_apply=*/called_program_shape)); + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kCall, operands); + }); +} + +XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, + const string& name) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (!parameter_numbers_.insert(parameter_number).second) { + return InvalidArgument("parameter %lld already registered", + parameter_number); + } + instr.set_parameter_number(parameter_number); + instr.set_name(name); + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kParameter); + }); +} + +XlaOp XlaBuilder::Broadcast( + const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + const Shape& shape, + ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); + + // The client-level broadcast op just appends dimensions on the left (adds + // lowest numbered dimensions). The HLO broadcast instruction is more + // flexible and can add new dimensions anywhere. The instruction's + // dimensions field maps operand dimensions to dimensions in the broadcast + // output, so to append dimensions on the left the instruction's dimensions + // should just be the n highest dimension numbers of the output shape where + // n is the number of input dimensions. + const int64 operand_rank = ShapeUtil::Rank(operand_shape); + std::vector dimensions(operand_rank); + for (int i = 0; i < operand_rank; ++i) { + dimensions[i] = i + ShapeUtil::Rank(shape) - operand_rank; + } + return InDimBroadcast(shape, operand, dimensions); + }); +} + +StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { + TF_RETURN_IF_ERROR(first_error_); + + HloInstructionProto instr; + *instr.mutable_shape() = shape; + return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand}); +} + +XlaOp XlaBuilder::Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferSliceShape(operand_shape, start_indices, + limit_indices, strides)); + for (int i = 0; i < start_indices.size(); i++) { + auto* slice_config = instr.add_slice_dimensions(); + slice_config->set_start(start_indices[i]); + slice_config->set_limit(limit_indices[i]); + slice_config->set_stride(strides[i]); + } + + return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); + }); +} + +XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + std::vector starts(ShapeUtil::Rank(shape), 0); + std::vector limits(shape.dimensions().begin(), + shape.dimensions().end()); + std::vector strides(ShapeUtil::Rank(shape), 1); + starts[dimno] = start_index; + limits[dimno] = limit_index; + strides[dimno] = stride; + return Slice(operand, starts, limits, strides); + }); +} + +XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicSliceShape( + operand_shape, start_indices_shape, slice_sizes)); + + for (int64 size : slice_sizes) { + instr.add_dynamic_slice_sizes(size); + } + + return AddInstruction(std::move(instr), HloOpcode::kDynamicSlice, + {operand, start_indices}); + }); +} + +XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDynamicUpdateSliceShape( + operand_shape, update_shape, start_indices_shape)); + + return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice, + {operand, update, start_indices}); + }); +} + +XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); + + instr.add_dimensions(dimension); + + return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); + }); +} + +XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape, + GetShape(padding_value)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferPadShape(operand_shape, padding_value_shape, + padding_config)); + + *instr.mutable_padding_config() = padding_config; + + return AddInstruction(std::move(instr), HloOpcode::kPad, + {operand, padding_value}); + }); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& shape, + ShapeInference::InferReshapeShape( + operand_shape, dimensions, new_sizes)); + XlaOp transposed = IsIdentityPermutation(dimensions) + ? operand + : Transpose(operand, dimensions); + return Reshape(shape, transposed); + }); +} + +XlaOp XlaBuilder::Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + std::vector dimensions(shape.dimensions_size()); + std::iota(dimensions.begin(), dimensions.end(), 0); + return Reshape(operand, dimensions, new_sizes); + }); +} + +XlaOp XlaBuilder::Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (dimensions.size() <= 1) { + // Not collapsing anything, trivially we can return the operand versus + // enqueueing a trivial reshape. + return operand; + } + + // Out-of-order collapse is not supported. + // Checks that the collapsed dimensions are in order and consecutive. + for (tensorflow::gtl::ArraySlice::size_type i = 1; + i < dimensions.size(); ++i) { + if (dimensions[i] - 1 != dimensions[i - 1]) { + return InvalidArgument( + "Collapsed dimensions are not in consecutive order."); + } + } + + // Create a new sizes vector from the old shape, replacing the collapsed + // dimensions by the product of their sizes. + TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand)); + + VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape); + VLOG(3) << "dims to collapse: " + << tensorflow::str_util::Join(dimensions, ","); + + std::vector new_sizes; + for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { + if (i <= dimensions.front() || i > dimensions.back()) { + new_sizes.push_back(original_shape.dimensions(i)); + } else { + new_sizes.back() *= original_shape.dimensions(i); + } + } + + VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") + << "]"; + + return Reshape(operand, new_sizes); + }); +} + +void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); + }); +} + +XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, + const XlaOp& on_false) { + return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false); +} + +XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferVariadicOpShape( + HloOpcode::kTuple, operand_shape_ptrs)); + return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); + }); +} + +XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data)); + if (!ShapeUtil::IsTuple(tuple_shape)) { + return InvalidArgument( + "Operand to GetTupleElement() is not a tuple; got %s", + ShapeUtil::HumanString(tuple_shape).c_str()); + } + *instr.mutable_shape() = + ShapeUtil::GetTupleElementShape(tuple_shape, index); + + instr.set_tuple_index(index); + + return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, + {tuple_data}); + }); +} + +XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape.dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); + }); +} + +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, + dimension_numbers)); + *instr.mutable_dot_dimension_numbers() = dimension_numbers; + return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); + }); +} + +Status XlaBuilder::VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const { + if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { + return InvalidArgument( + "Convolution arguments must have same number of " + "dimensions. Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str()); + } + int num_dims = ShapeUtil::Rank(lhs_shape); + if (num_dims < 2) { + return InvalidArgument( + "Convolution expects argument arrays with >= 3 dimensions. " + "Got: %s and %s", + ShapeUtil::HumanString(lhs_shape).c_str(), + ShapeUtil::HumanString(rhs_shape).c_str()); + } + int num_spatial_dims = num_dims - 2; + + const auto check_spatial_dimensions = + [&](const char* const field_name, + const tensorflow::protobuf::RepeatedField& + numbers) { + if (numbers.size() != num_spatial_dims) { + return InvalidArgument("Expected %d elements for %s, but got %d.", + num_spatial_dims, field_name, numbers.size()); + } + for (int i = 0; i < numbers.size(); ++i) { + if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { + return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + field_name, i, numbers.Get(i)); + } + } + return Status::OK(); + }; + TF_RETURN_IF_ERROR( + check_spatial_dimensions("input_spatial_dimensions", + dimension_numbers.input_spatial_dimensions())); + TF_RETURN_IF_ERROR( + check_spatial_dimensions("kernel_spatial_dimensions", + dimension_numbers.kernel_spatial_dimensions())); + return check_spatial_dimensions( + "output_spatial_dimensions", + dimension_numbers.output_spatial_dimensions()); +} + +XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding) { + return ConvWithGeneralDimensions( + lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); +} + +XlaOp XlaBuilder::ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return ConvGeneral(lhs, rhs, window_strides, padding, + CreateDefaultConvDimensionNumbers(window_strides.size())); +} + +XlaOp XlaBuilder::ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + + TF_RETURN_IF_ERROR( + VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers)); + + std::vector base_area_dimensions( + dimension_numbers.input_spatial_dimensions_size()); + for (std::vector::size_type i = 0; i < base_area_dimensions.size(); + ++i) { + base_area_dimensions[i] = + lhs_shape.dimensions(dimension_numbers.input_spatial_dimensions(i)); + } + + std::vector window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (std::vector::size_type i = 0; i < window_dimensions.size(); + ++i) { + window_dimensions[i] = + rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + + return ConvGeneral(lhs, rhs, window_strides, + MakePadding(base_area_dimensions, window_dimensions, + window_strides, padding), + dimension_numbers); + }); +} + +XlaOp XlaBuilder::ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers) { + return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, + dimension_numbers); +} + +XlaOp XlaBuilder::ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_RETURN_IF_ERROR( + VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers)); + + std::vector window_dimensions( + dimension_numbers.kernel_spatial_dimensions_size()); + for (std::vector::size_type i = 0; i < window_dimensions.size(); + ++i) { + window_dimensions[i] = + rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + lhs_dilation, rhs_dilation)); + + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), + dimension_numbers)); + + *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + + return AddInstruction(std::move(instr), HloOpcode::kConvolution, + {lhs, rhs}); + }); +} + +StatusOr XlaBuilder::MakeWindow( + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation) const { + const auto verify_size = [&](const size_t x, const char* x_name) { + if (x == 0 || x == window_dimensions.size()) { + return Status::OK(); + } else { + return InvalidArgument( + "%s", tensorflow::strings::StrCat( + "Window has different number of window dimensions than of ", + x_name, + "\nNumber of window dimensions: ", window_dimensions.size(), + "\nNumber of ", x_name, ": ", x, "\n") + .c_str()); + } + }; + TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); + TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); + TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); + TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); + + Window window; + for (size_t i = 0; i < window_dimensions.size(); i++) { + auto dim = window.add_dimensions(); + dim->set_size(window_dimensions[i]); + if (!window_strides.empty()) { + dim->set_stride(window_strides[i]); + } else { + dim->set_stride(1); + } + if (!padding.empty()) { + dim->set_padding_low(padding[i].first); + dim->set_padding_high(padding[i].second); + } else { + dim->set_padding_low(0); + dim->set_padding_high(0); + } + if (!lhs_dilation.empty()) { + dim->set_base_dilation(lhs_dilation[i]); + } else { + dim->set_base_dilation(1); + } + if (!rhs_dilation.empty()) { + dim->set_window_dilation(rhs_dilation[i]); + } else { + dim->set_window_dilation(1); + } + dim->set_window_reversal(false); + } + return window; +} + +XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferFftShape(operand_shape, fft_type, fft_length)); + + instr.set_fft_type(fft_type); + for (int64 i : fft_length) { + instr.add_fft_length(i); + } + + return AddInstruction(std::move(instr), HloOpcode::kFft, {operand}); + }); +} + +XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Infeed must have a layout"); + } + *instr.mutable_shape() = shape; + instr.set_infeed_config(config); + return AddInstruction(std::move(instr), HloOpcode::kInfeed); + }); +} + +void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + *instr.mutable_shape() = ShapeUtil::MakeNil(); + + // Check and set outfeed shape. + if (!LayoutUtil::HasLayout(shape_with_layout)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { + return InvalidArgument( + "Outfeed shape %s must be compatible with operand shape %s", + ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + } + *instr.mutable_outfeed_shape() = shape_with_layout; + + instr.set_outfeed_config(outfeed_config); + + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}); + }); +} + +XlaOp XlaBuilder::CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + if (tensorflow::str_util::StartsWith(call_target_name, "$")) { + return InvalidArgument( + "Invalid custom_call_target \"%s\": Call targets that start with '$' " + "are reserved for internal use.", + call_target_name.c_str()); + } + *instr.mutable_shape() = shape; + instr.set_custom_call_target(call_target_name); + return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); + }); +} + +XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, + int64 cost_estimate_ns, const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.set_channel_name(channel_name); + instr.set_cost_estimate_ns(cost_estimate_ns); + return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands); + }); +} + +XlaOp XlaBuilder::Complex( + const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); +} + +XlaOp XlaBuilder::Conj(const XlaOp& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} + +XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return Or(And(Not(lhs), rhs, broadcast_dimensions), + And(lhs, Not(rhs), broadcast_dimensions)); +} + +XlaOp XlaBuilder::Not(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNot, operand); +} + +XlaOp XlaBuilder::ShiftLeft( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, + broadcast_dimensions); +} + +XlaOp XlaBuilder::Abs(const XlaOp& operand) { + return UnaryOp(HloOpcode::kAbs, operand); +} + +XlaOp XlaBuilder::Atan2( + const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); +} + +XlaOp XlaBuilder::Exp(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExp, operand); +} + +XlaOp XlaBuilder::Floor(const XlaOp& operand) { + return UnaryOp(HloOpcode::kFloor, operand); +} + +XlaOp XlaBuilder::Ceil(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCeil, operand); +} + +XlaOp XlaBuilder::Round(const XlaOp& operand) { + return UnaryOp(HloOpcode::kRoundNearestAfz, operand); +} + +XlaOp XlaBuilder::Log(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog, operand); +} + +XlaOp XlaBuilder::Sign(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSign, operand); +} + +XlaOp XlaBuilder::Cos(const XlaOp& operand) { + return UnaryOp(HloOpcode::kCos, operand); +} + +XlaOp XlaBuilder::Sin(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSin, operand); +} + +XlaOp XlaBuilder::Tanh(const XlaOp& operand) { + return UnaryOp(HloOpcode::kTanh, operand); +} + +XlaOp XlaBuilder::Real(const XlaOp& operand) { + return UnaryOp(HloOpcode::kReal, operand); +} + +XlaOp XlaBuilder::Imag(const XlaOp& operand) { + return UnaryOp(HloOpcode::kImag, operand); +} + +XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { + return UnaryOp(HloOpcode::kIsFinite, operand); +} + +XlaOp XlaBuilder::Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferTransposeShape(operand_shape, permutation)); + for (int64 dim : permutation) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kTranspose, {operand}); + }); +} + +XlaOp XlaBuilder::Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReverseShape(operand_shape, dimensions)); + for (int64 dim : dimensions) { + instr.add_dimensions(dim); + } + return AddInstruction(std::move(instr), HloOpcode::kReverse, {operand}); + }); +} + +XlaOp XlaBuilder::Sort(const XlaOp& operand) { + return UnaryOp(HloOpcode::kSort, operand); +} + +XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(0.5), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); +} + +XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvertShape(operand_shape, new_element_type)); + return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand}); + }); +} + +XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConvertShape(operand_shape, new_element_type)); + return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, + {operand}); + }); +} + +XlaOp XlaBuilder::SquareF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(2.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) { + return BinaryOp(HloOpcode::kPower, operand, ConstantR0(-1.0), + /*broadcast_dimensions=*/{}); +} + +XlaOp XlaBuilder::Neg(const XlaOp& operand) { + return UnaryOp(HloOpcode::kNegate, operand); +} + +XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, + const XlaOp& max) { + return TernaryOp(HloOpcode::kClamp, min, operand, max); +} + +XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (!static_operands.empty()) { + return Unimplemented("static_operands is not supported in Map"); + } + + HloInstructionProto instr; + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape, + dimensions)); + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kMap, operands); + }); +} + +XlaOp XlaBuilder::RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Check the number of parameters per RNG distribution. + switch (distribution) { + case RandomDistribution::RNG_NORMAL: + case RandomDistribution::RNG_UNIFORM: + if (parameters.size() != 2) { + return InvalidArgument( + "RNG distribution (%s) expects 2 parameters, but got %ld", + RandomDistribution_Name(distribution).c_str(), parameters.size()); + } + break; + default: + LOG(FATAL) << "unhandled distribution " << distribution; + } + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + *instr.mutable_shape() = shape; + + instr.set_distribution(distribution); + + return AddInstruction(std::move(instr), HloOpcode::kRng, parameters); + }); +} + +XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); +} + +XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b, + const Shape& shape) { + return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); +} + +XlaOp XlaBuilder::While(const XlaComputation& condition, + const XlaComputation& body, const XlaOp& init) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Infer shape. + TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, + condition.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferWhileShape(condition_program_shape, + body_program_shape, init_shape)); + // Body comes before condition computation in the vector. + AddCalledComputation(body, &instr); + AddCalledComputation(condition, &instr); + return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); + }); +} + +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, + GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferGatherShape(input_shape, gather_indices_shape, + dimension_numbers, window_bounds)); + + *instr.mutable_gather_dimension_numbers() = dimension_numbers; + for (int64 bound : window_bounds) { + instr.add_gather_window_bounds(bound); + } + + return AddInstruction(std::move(instr), HloOpcode::kGather, + {input, gather_indices}); + }); +} + +XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate)); + TF_ASSIGN_OR_RETURN(const Shape& true_operand_shape, + GetShape(true_operand)); + TF_ASSIGN_OR_RETURN(const ProgramShape& true_computation_shape, + true_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const Shape& false_operand_shape, + GetShape(false_operand)); + TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape, + false_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferConditionalShape( + predicate_shape, true_operand_shape, false_operand_shape, + true_computation_shape, false_computation_shape)); + + // The index of true_computation must be 0 and that of false computation + // must be 1. + AddCalledComputation(true_computation, &instr); + AddCalledComputation(false_computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kConditional, + {predicate, true_operand, false_operand}); + }); +} + +XlaOp XlaBuilder::Reduce( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape, init_shape, dimensions_to_reduce, + called_program_shape)); + + for (int64 dim : dimensions_to_reduce) { + instr.add_dimensions(dim); + } + + AddCalledComputation(computation, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kReduce, + {operand, init_value}); + }); +} + +XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + std::vector all_dimnos(ShapeUtil::Rank(operand_shape)); + std::iota(all_dimnos.begin(), all_dimnos.end(), 0); + return Reduce(operand, init_value, computation, all_dimnos); + }); +} + +XlaOp XlaBuilder::ReduceWindow( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_RETURN_IF_ERROR( + ValidatePaddingValues(AsInt64Slice(operand_shape.dimensions()), + window_dimensions, window_strides)); + + std::vector> padding_values = + MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, + window_strides, padding); + return ReduceWindowWithGeneralPadding(operand, init_value, computation, + window_dimensions, window_strides, + padding_values); + }); +} + +XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReduceWindowShape(operand_shape, init_shape, + instr.window(), to_apply_shape)); + + AddCalledComputation(computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kReduceWindow, + {operand, init_value}); + }); +} + +XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); + TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferBatchNormTrainingShape( + operand_shape, scale_shape, offset_shape, feature_index)); + + instr.set_epsilon(epsilon); + instr.set_feature_index(feature_index); + + return AddInstruction(std::move(instr), HloOpcode::kBatchNormTraining, + {operand, scale, offset}); + }); +} + +XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); + TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset)); + TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean)); + TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBatchNormInferenceShape( + operand_shape, scale_shape, offset_shape, + mean_shape, variance_shape, feature_index)); + + instr.set_epsilon(epsilon); + instr.set_feature_index(feature_index); + + return AddInstruction(std::move(instr), HloOpcode::kBatchNormInference, + {operand, scale, offset, mean, variance}); + }); +} + +XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale)); + TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean)); + TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var)); + TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferBatchNormGradShape( + operand_shape, scale_shape, batch_mean_shape, + batch_var_shape, grad_output_shape, feature_index)); + + instr.set_epsilon(epsilon); + instr.set_feature_index(feature_index); + + return AddInstruction(std::move(instr), HloOpcode::kBatchNormGrad, + {operand, scale, batch_mean, batch_var, grad_output}); + }); +} + +XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + + return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, + {operand}); + }); +} + +XlaOp XlaBuilder::SelectAndScatter( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return NoteErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + return SelectAndScatterWithGeneralPadding( + operand, select, window_dimensions, window_strides, + MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, + window_strides, padding), + source, init_value, scatter); + }); +} + +XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(const Shape& source_shape, GetShape(source)); + TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); + TF_ASSIGN_OR_RETURN(const ProgramShape& select_shape, + select.GetProgramShape()); + TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape, + scatter.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_window(), + MakeWindow(window_dimensions, window_strides, padding, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferSelectAndScatterShape( + operand_shape, select_shape, instr.window(), + source_shape, init_shape, scatter_shape)); + + AddCalledComputation(select, &instr); + AddCalledComputation(scatter, &instr); + + return AddInstruction(std::move(instr), HloOpcode::kSelectAndScatter, + {operand, source, init_value}); + }); +} + +XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferReducePrecisionShape( + operand_shape, exponent_bits, mantissa_bits)); + instr.set_exponent_bits(exponent_bits); + instr.set_mantissa_bits(mantissa_bits); + return AddInstruction(std::move(instr), HloOpcode::kReducePrecision, + {operand}); + }); +} + +void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) { + NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Send instruction produces a tuple of {aliased operand, U32 context}. + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN( + XlaOp send, + AddInstruction(std::move(instr), HloOpcode::kSend, {operand})); + + HloInstructionProto send_done_instr; + *send_done_instr.mutable_shape() = ShapeUtil::MakeNil(); + send_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone, + {send}); + }); +} + +XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { + return NoteErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + + // Recv instruction produces a tuple of {receive buffer, U32 context}. + *instr.mutable_shape() = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + instr.set_channel_id(handle.handle()); + TF_ASSIGN_OR_RETURN(XlaOp recv, + AddInstruction(std::move(instr), HloOpcode::kRecv, {})); + + HloInstructionProto recv_done_instr; + *recv_done_instr.mutable_shape() = shape; + recv_done_instr.set_channel_id(handle.handle()); + return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone, + {recv}); + }); +} + +StatusOr XlaBuilder::IsConstant(const XlaOp& operand) const { + TF_RETURN_IF_ERROR(first_error_); + + // Verify that the handle is valid. + TF_RETURN_IF_ERROR(LookUpInstruction(operand).status()); + + bool is_constant = true; + std::set visited; + IsConstantVisitor(operand.handle(), &visited, &is_constant); + return is_constant; +} + +StatusOr XlaBuilder::BuildConstantSubGraph( + const XlaOp& root_op) const { + TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); + if (!is_constant) { + auto op_status = LookUpInstruction(root_op); + string op_string = + op_status.ok() ? op_status.ValueOrDie()->name() : ""; + return InvalidArgument( + "Operand to BuildConstantSubGraph depends on a parameter.\n\n" + " op requested for constant subgraph: %s\n\n" + "This is an internal error that typically happens when the XLA user " + "(e.g. TensorFlow) is attempting to determine a value that must be a " + "compile-time constant (e.g. an array dimension) but it is not capable " + "of being evaluated at XLA compile time.\n\n" + "Please file a usability bug with the framework being used (e.g. " + "TensorFlow).", + op_string.c_str()); + } + + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, + LookUpInstruction(root_op)); + TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(root->opcode())); + if (!CanBeRoot(opcode)) { + return InvalidArgument("the operand with opcode %s cannot be root", + root->opcode().c_str()); + } + + HloComputationProto entry; + entry.set_id(GetUniqueId()); // Give the computation a global unique id. + entry.set_name(StrCat(name_, entry.id(), "_compute_constant")); + entry.set_root_id(root->id()); + ProgramShape* program_shape = entry.mutable_program_shape(); + *program_shape->mutable_result() = root->shape(); + + // We use std::set to keep the instruction ids in ascending order (which is + // also a valid denpendency order). The related ops will be added to the + // subgraph in the same order. + std::set related_ops; + tensorflow::gtl::FlatSet related_calls; // Related computations. + std::queue worklist; + worklist.push(root->id()); + related_ops.insert(root->id()); + while (!worklist.empty()) { + int64 node = worklist.front(); + worklist.pop(); + for (int64 id : instructions_[node].operand_ids()) { + if (related_ops.insert(id).second) { + worklist.push(id); + } + } + for (int64 called_id : instructions_[node].called_computation_ids()) { + related_calls.insert(called_id); + } + } + + // Add related ops to the computation. + for (int64 id : related_ops) { + auto* instr = entry.add_instructions(); + *instr = instructions_[id]; + // Ensures that the instruction names are unique among the graph. + const string& new_name = + StrCat(instr->name(), ".", entry.id(), ".", instr->id()); + instr->set_name(new_name); + } + + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_program_shape() = *program_shape; + for (auto& e : embedded_) { + if (related_calls.find(e.second.id()) != related_calls.end()) { + *module->add_computations() = e.second; + } + } + *module->add_computations() = std::move(entry); + + return std::move(computation); +} + +std::unique_ptr XlaBuilder::CreateSubBuilder( + const string& computation_name) { + auto sub_builder = MakeUnique(computation_name); + sub_builder->parent_builder_ = this; + sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; + return sub_builder; +} + +/* static */ ConvolutionDimensionNumbers +XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_kernel_output_feature_dimension( + kConvKernelOutputDimension); + dimension_numbers.set_kernel_input_feature_dimension( + kConvKernelInputDimension); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(i + 2); + dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); + } + return dimension_numbers; +} + +/* static */ Status XlaBuilder::Validate( + const ConvolutionDimensionNumbers& dnum) { + if (dnum.input_spatial_dimensions_size() < 2) { + return FailedPrecondition("input spacial dimension < 2: %d", + dnum.input_spatial_dimensions_size()); + } + if (dnum.kernel_spatial_dimensions_size() < 2) { + return FailedPrecondition("kernel spacial dimension < 2: %d", + dnum.kernel_spatial_dimensions_size()); + } + if (dnum.output_spatial_dimensions_size() < 2) { + return FailedPrecondition("output spacial dimension < 2: %d", + dnum.output_spatial_dimensions_size()); + } + + if (std::set( + {dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the input are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.input_batch_dimension(), dnum.input_feature_dimension(), + dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); + } + if (std::set({dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), + dnum.kernel_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.kernel_output_feature_dimension(), + dnum.kernel_input_feature_dimension(), + dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); + } + if (std::set({dnum.output_batch_dimension(), + dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), + dnum.output_spatial_dimensions(1)}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + dnum.output_batch_dimension(), dnum.output_feature_dimension(), + dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); + } + return Status::OK(); +} + +StatusOr XlaBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + TF_RETURN_IF_ERROR(first_error_); + + const int64 handle = instructions_.size(); + instr.set_id(handle); + instr.set_opcode(HloOpcodeString(opcode)); + if (instr.name().empty()) { + instr.set_name(StrCat(instr.opcode())); + } + for (const auto& operand : operands) { + if (operand.builder_ == nullptr) { + return InvalidArgument("invalid XlaOp with handle %lld", + operand.handle()); + } + if (operand.builder_ != this) { + return InvalidArgument("Do not add XlaOp from builder %s to builder %s", + operand.builder_->name().c_str(), + this->name().c_str()); + } + instr.add_operand_ids(operand.handle()); + } + + *instr.mutable_metadata() = metadata_; + if (sharding_) { + *instr.mutable_sharding() = *sharding_; + } + + instructions_.push_back(instr); + + XlaOp op(handle, this); + return op; +} + +void XlaBuilder::AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr) { + instr->add_called_computation_ids(computation.proto().entry_computation_id()); + for (const HloComputationProto& e : computation.proto().computations()) { + embedded_.insert({e.id(), e}); + } +} + +StatusOr XlaBuilder::LookUpInstruction( + const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + + if (op.builder_ != this) { + return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + } + + TF_RET_CHECK(op.builder_ == this); + if (op.handle() >= instructions_.size() || op.handle() < 0) { + return InvalidArgument("no XlaOp value %lld", op.handle()); + } + return &instructions_[op.handle()]; +} + +XlaOp XlaBuilder::UnimplementedOp() { + NoteError(Unimplemented("Op not implemented")); + return {}; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..1f7c731064dc004adcac56547e4717ff1638a491 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -0,0 +1,996 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(b/74197823): Replace computation_builder.h with this file. +// +// This is NOT YET ready to use. + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stacktrace.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class XlaBuilder; + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +// +// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. +class XlaOp { + public: + XlaOp() : handle_(0), builder_(nullptr) {} + ~XlaOp() {} + + StatusOr GetShape() const; + + private: + XlaOp(int64 handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64 handle() const { return handle_; } + friend class XlaBuilder; + + int64 handle_; + XlaBuilder* builder_; // Not owned. +}; + +// A convenient interface for building up computations. +// +// Thread-compatible. +// +// TODO(b/74197823): Replace xla::ComputationBuilder with this one. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + XlaBuilder(const string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + ~XlaBuilder(); + + // Returns the computation name. + const string& name() const { return name_; } + + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the Computation Builder. All subsequent + // instructions generated via this Computation Builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Enqueues a "retrieve parameter value" instruction for a parameter that was + // passed to the computation. + XlaOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + // Enqueues a constant with the value of the given literal onto the + // computation. + XlaOp ConstantLiteral(const Literal& literal); + + // Enqueues a constant onto the computation. Methods are templated on the + // native host type (NativeT) which corresponds to a specific XLA + // PrimitiveType as given in the following table: + // + // Native Type PrimitiveType + // ----------------------------- + // bool PRED + // int32 S32 + // int64 S64 + // uint32 U32 + // uint64 U64 + // float F32 + // double F64 + // + // Note: not all primitive types defined in xla_data.proto have a + // corresponding native type yet. + template + XlaOp ConstantR0(NativeT value); + template + XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(const tensorflow::core::Bitmap& values); + template + XlaOp ConstantR2( + std::initializer_list> values); + template + XlaOp ConstantFromArrayWithLayout(const Array& values, + const Layout& layout); + template + XlaOp ConstantFromArray(const Array& values); + template + XlaOp ConstantR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); + template + XlaOp ConstantR2FromArray2D(const Array2D& values); + template + XlaOp ConstantR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); + template + XlaOp ConstantR3FromArray3D(const Array3D& values); + template + XlaOp ConstantR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); + template + XlaOp ConstantR4FromArray4D(const Array4D& values); + + // Enqueues a rank one constant (vector) onto the computation. The vector has + // size 'length' and every element has the value 'value'. + template + XlaOp ConstantR1(int64 length, NativeT value); + + // Adds dimensions to an array by duplicating the data in the array. + // + // The new dimensions are inserted on the left, i.e. if + // broadcast_sizes has values {a0, ..., aN} and the operand shape + // has dimensions {b0, ..., bM} then the shape of the output has + // dimensions {a0, ..., aN, b0, ..., bM}. + // + // The new dimensions index into copies of the operand, i.e. + // + // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] + XlaOp Broadcast(const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + // Enqueues a pad operation onto the computation that pads the given value on + // the edges as well as between the elements of the input. padding_config + // specifies the padding amount for each dimension. + XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, + const PaddingConfig& padding_config); + + // Enqueues an operation onto the computation that flattens the operand based + // on the dimension order (major/slowest-varying to minor/fastest-varying) + // given, followed by reshaping it into the shape with the given dimension + // sizes (also major to minor). Conceptually, this is a limited form of + // "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + // Enqueues an operation onto the computation that collapses the operand, from + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". + XlaOp Reshape(const XlaOp& operand, + tensorflow::gtl::ArraySlice new_sizes); + + // Wrapper for Reshape. + // Enqueues an operation to collapse the provided dimensions; e.g. an + // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to + // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must + // be a consecutive, in-order subsequence of the operand dimensions. + // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // + // This could potentially cause data to be moved -- it provides a more + // structured form of reshaping than an arbitrary Reshape operation. + XlaOp Collapse(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a slice operation onto the computation that slices the operand + // from the start indices to the limit indices; e.g. + // + // x + // [ 0 1 2 3 ] + // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] + // [ 8 9 a b ] + // + // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D + // range notation. + // The strides parameter determines the stride over the slice + XlaOp Slice(const XlaOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + // Enqueues a slice operation in a given dimension, taking all other + // dimensions as they are; e.g. if dimno is 1 from start_index 2 to + // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand + // for: + // + // array[:, 2:4:1, :] + XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + + // Enqueues a slice operation onto the computation that slices the 'operand' + // from dynamic start indices which are passed in 'start_indices'. + // The size of the slice in each dimension is passed in 'slice_sizes', + // which specify the end point of exclusive slice intervals in each + // dimension [start, start + size). + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo input dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + // Enqueues a dynamic update slice operation onto the computation, which + // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. + // The shape of 'update' determines the shape of the slice of 'operand' + // which is updated. + // The indices specified in 'start_indices' specify the offset of the slice + // of 'operand' which is updated. + // + // update = {10, 11} // calculated at runtime. + // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] + // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] + // [7 8 9] [7 8 9 ] + // + // The shape of 'start_indices' must be rank == 1, with dimension size + // equal to the rank of the 'operand'. + // Slice index calculations are computed modulo update dimension sizes to + // prevent dynamic start indices from generating out-of-bound array accesses. + XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, + const XlaOp& start_indices); + + // Enqueues a concatenate instruction onto the computation. 'operands' must + // have >= 1 entry. + XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); + + // Enqueue a tracing operation onto the computation; the computation will emit + // a logging message with the operand. + void Trace(const string& tag, const XlaOp& operand); + + // Enqueues a conditional-move-like select operation onto the computation; + // predicated on pred, selects between on_true and on_false. + XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); + + // Enqueues a tuple-creation instruction onto the computation. + XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + + // Enqueues a tuple-element-get instruction onto the computation. + XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); + + // Enqueues an equal-to comparison instruction onto the computation. + XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a not-equal comparison instruction onto the computation. + XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-or-equal comparison instruction onto the computation. + XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a greater-than comparison instruction onto the computation. + XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-than comparison instruction onto the computation. + XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a less-or-equal comparison instruction onto the computation. + XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a dot instruction onto the computation. + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + + // Enqueues a general dot instruction onto the computation. + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers); + + // Default dimension numbers used for a 2D convolution. + static constexpr int64 kConvBatchDimension = 0; + static constexpr int64 kConvFeatureDimension = 1; + static constexpr int64 kConvFirstSpatialDimension = 2; + static constexpr int64 kConvSecondSpatialDimension = 3; + static constexpr int64 kConvKernelOutputDimension = 0; + static constexpr int64 kConvKernelInputDimension = 1; + static constexpr int64 kConvKernelFirstSpatialDimension = 2; + static constexpr int64 kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Enqueues a convolution instruction onto the computation, which uses the + // default convolution dimension numbers. + XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration in the format returned by MakePadding(). + XlaOp ConvWithGeneralPadding( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided dimension numbers configuration. + XlaOp ConvWithGeneralDimensions( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration as well as the dimension numbers. + XlaOp ConvGeneral( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues a convolution instruction onto the computation, with the caller + // provided padding configuration, dilation factors and dimension numbers. + XlaOp ConvGeneralDilated( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + XlaOp Fft(const XlaOp& operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. + XlaOp Infeed(const Shape& shape, const string& config = ""); + + // Enqueues an outfeed instruction onto the computation. This instruction + // generates outgoing data transfers for the given data. + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, + const string& outfeed_config); + + // Enqueues a call instruction onto the computation. + XlaOp Call(const XlaComputation& computation, + tensorflow::gtl::ArraySlice operands); + + // Enqueues a custom call instruction onto the computation. + // During code generation, a call instruction is emitted which targets a + // symbol with the name |call_target_name|. The |operands| are passed to the + // call instruction. |shape| is the resultant shape. + XlaOp CustomCall(const string& call_target_name, + tensorflow::gtl::ArraySlice operands, + const Shape& shape); + + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, + const Shape& shape); + + // The following methods enqueue element-wise binary arithmetic operations + // onto the computation. The shapes of the operands have to match unless one + // of the operands is a scalar, or an explicit broadcast dimension is given + // (see g3doc for more details). + + // Enqueues a complex compose instruction onto the computation. + XlaOp Complex(const XlaOp& real, const XlaOp& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + XlaOp Conj(const XlaOp& operand); + + // Enqueues an add instruction onto the computation. + XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a subtract instruction onto the computation. + XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a multiply instruction onto the computation. + XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a divide instruction onto the computation. + XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a remainder instruction onto the computation. + XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a max instruction onto the computation. + XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a min instruction onto the computation. + XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Element-wise logical operators + XlaOp And(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + XlaOp Not(const XlaOp& operand); + + XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + XlaOp ShiftRightLogical( + const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Reduces an array among the provided dimensions, given "computation" as a + // reduction operator. + XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + // Convenience wrapper around the above that reduces all the dimensions in the + // operand shape. + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation); + + // Enqueues a windowed reduce instruction onto the computation. + XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding); + + // As ReduceWindow(), but the padding is given in the format + // returned by MakePadding(). + XlaOp ReduceWindowWithGeneralPadding( + const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding); + + // Returns the sum of the operand value across all replicas. All replicas + // supply one input to the sum and all replicas receive the resulting sum. + XlaOp CrossReplicaSum(const XlaOp& operand); + + // Enqueues an operation that scatters the `source` array to the selected + // indices of each window. + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); + + // As SelectAndScatter(), but the padding is given in the format + // returned by MakePadding(). + XlaOp SelectAndScatterWithGeneralPadding( + const XlaOp& operand, const XlaComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); + + // Enqueues an abs instruction onto the computation. + XlaOp Abs(const XlaOp& operand); + + // Enqueues a atan2 instruction onto the computation. + XlaOp Atan2(const XlaOp& y, const XlaOp& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an exp instruction onto the computation. + XlaOp Exp(const XlaOp& operand); + + // Enqueues a floor instruction onto the computation. + XlaOp Floor(const XlaOp& operand); + + // Enqueues a ceil instruction onto the computation. + XlaOp Ceil(const XlaOp& operand); + + // Enqueues a round instruction onto the computation, rounding to nearest even + // with half-way cases rounding away from zero. + XlaOp Round(const XlaOp& operand); + + // Enqueues an log instruction (natural logarithm) onto the computation. + XlaOp Log(const XlaOp& operand); + + // Enqueues a sign instruction onto the computation. + XlaOp Sign(const XlaOp& operand); + + // Enqueues a cosine instruction onto the computation. + XlaOp Cos(const XlaOp& operand); + + // Enqueues a sine instruction onto the computation. + XlaOp Sin(const XlaOp& operand); + + // Enqueues a tanh instruction onto the computation. + XlaOp Tanh(const XlaOp& operand); + + // Enqueues a real-part instruction onto the computation. + XlaOp Real(const XlaOp& operand); + + // Enqueues an imaginary-part instruction onto the computation. + XlaOp Imag(const XlaOp& operand); + + // Enqueues a float32 sqrt instruction onto the computation. + // (float32 is specified as there is an implicit float32 0.5f constant + // exponent). + XlaOp SqrtF32(const XlaOp& operand); + + // Enqueues a float32 square instruction onto the computation. + // (float32 is specified as there is an implicit float32 2.0f constant + // exponent). + XlaOp SquareF32(const XlaOp& operand); + + // Enqueues a lhs^rhs computation onto the computation. + XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues an operator that tests if the operand's values are finite, i.e., + // not Inf or NaN. Defined only for floating-point types. Returns an array of + // booleans with the same shape where entries are true iff the corresponding + // entry was NaN. + XlaOp IsFinite(const XlaOp& operand); + + // Enqueues a convert instruction onto the computation that changes the + // element type of the operand array to primitive_type. + XlaOp ConvertElementType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + XlaOp BitcastConvertType(const XlaOp& operand, + PrimitiveType new_element_type); + + // Enqueues a float32 reciprocal instruction onto the computation. + // (float32 is specified as there is an implicit float32 -1.0f constant + // exponent). + // + // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the + // shape of the operand. + XlaOp ReciprocalF32(const XlaOp& operand); + + // Enqueues a negate instruction onto the computation. + XlaOp Neg(const XlaOp& operand); + + // Enqueues a transpose instruction onto the computation. + XlaOp Transpose(const XlaOp& operand, + tensorflow::gtl::ArraySlice permutation); + + // Enqueues a reverse instruction onto the computation. The order of the + // elements in the given dimensions is reversed (i.e., the element at index i + // is moved to index dimension_size - 1 - i). + XlaOp Rev(const XlaOp& operand, + tensorflow::gtl::ArraySlice dimensions); + + // Enqueues a sort (as increasing order) instruction onto the computation. + XlaOp Sort(const XlaOp& operand); + + // Enqueues a clamp instruction onto the computation. + XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); + + // Enqueues a map instruction onto the computation. + XlaOp Map(tensorflow::gtl::ArraySlice operands, + const XlaComputation& computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands = {}); + + // Enqueues a N(mu, sigma) random number generation instruction onto the + // computation. + XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); + + // Enqueues a U(a, b) random number generation instruction onto the + // computation. Returns values in the semi-open interval [a, b). + XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); + + // Enqueues a while node onto the computation. + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + const XlaOp& init); + + // Enqueues a conditional node onto the computation. + XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand, + const XlaComputation& true_computation, + const XlaOp& false_operand, + const XlaComputation& false_computation); + + // Enqueues a ReducePrecision node onto the computation. + XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, + const int mantissa_bits); + + // Enqueues a Gather node onto the computation. + XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + + // Enqueues a Send node onto the computation, to send the given operand to + // a Recv instruction that shares the same channel handle. + void Send(const XlaOp& operand, const ChannelHandle& handle); + + // Enqueues a Recv node onto the computation. The data comes from a Send + // instruction that shares the same channel handle and its shape must + // be the same as the given shape. + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + StatusOr IsConstant(const XlaOp& operand) const; + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` + // is the normalized result and batch_mean and batch_var are the mean and + // variance, respectively, across batch for the operand. + XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, float epsilon, + int64 feature_index); + + // Normalizes operand across spatial and batch dimensions for each feature. + // + // `BatchNormInference` is equivalent to calling `BatchNormTraining` without + // computing `mean` and `variance` for each batch inside the operation. It + // uses the input `mean` and `variance` instead as estimated values. The + // purpose of this op is to reduce latency in inference, hence the name + // `BatchNormInference`. + // + // The output has the same shape as `operand`, and contains the normalized + // values for each batch. + XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale, + const XlaOp& offset, const XlaOp& mean, + const XlaOp& variance, float epsilon, + int64 feature_index); + + // Calculates the gradients of a batch norm op. + // + // The inputs `batch_mean` and `batch_var` represent the mean and variance + // across the batch. + // + // Returns a tuple of three elements: + // - grad_operand: Gradient with respect to input `operand` + // - grad_offset: Gradient with respect to input `offset` + // - grad_scale: Gradient with respect to input `scale` + XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, + const XlaOp& batch_mean, const XlaOp& batch_var, + const XlaOp& grad_output, float epsilon, + int64 feature_index); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder(const string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. + StatusOr Build(); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + StatusOr BuildConstantSubGraph(const XlaOp& root_op) const; + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + Status first_error() const { return first_error_; } + + // Returns the shape of the given op. + StatusOr GetShape(const XlaOp& op) const; + + // Returns the (inferred) result for the current computation's shape. + StatusOr GetProgramShape() const; + + private: + StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + tensorflow::gtl::ArraySlice operands = {}); + + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); + + // Notes that the error occurred by: + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to Build()) + // * dying if die_immediately_on_error_ is true + void NoteError(const Status& error); + + XlaOp NoteErrorOrReturn(const std::function()>& op_creator); + + // Helper method that creates an empty op and notes error. + XlaOp UnimplementedOp(); + + StatusOr LookUpInstruction(const XlaOp& op) const; + + // Internal helper method that does the building for an arbitrary unary op. + XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. + XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, + const XlaOp& ehs); + + XlaOp RngOp(RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters, + const Shape& shape); + + StatusOr InDimBroadcast( + const Shape& shape, const XlaOp& operand, + tensorflow::gtl::ArraySlice broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + StatusOr AddBroadcastSequence(const Shape& output_shape, + const XlaOp& operand); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + StatusOr Reshape(const Shape& shape, const XlaOp& operand); + + // Returns the (inferred) result for the program shape for the current + // computation and fills the root_id in the pointer. + StatusOr GetProgramShape(int64* root_id) const; + + // A visitor which checks whether an operation is a compile-time constant, + // meaning that it doesn't depend on any parameters, or on any stateful + // operation such as `RngNormal` or `Infeed`. The visitor walks the + // computation starting at a given operation and sets is_constant to false iff + // a parameter or stateful operation is encountered. + void IsConstantVisitor(const int64 op_handle, std::set* visited, + bool* is_constant) const; + + // Checks bounds for convolution parameters. + Status VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const; + + // Helper function for creating a Window proto from user-supplied data. + // Returns error if the user-supplied data was invalid. + StatusOr MakeWindow( + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation) const; + + string name_; // Name to use for the built computation. + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tensorflow::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + std::vector instructions_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + tensorflow::gtl::FlatSet parameter_numbers_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional sharding_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; +}; + +template +XlaOp XlaBuilder::ConstantR0(NativeT value) { + return ConstantLiteral(*Literal::CreateR0(value)); +} + +template +XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); +} + +inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { + return ConstantLiteral(*Literal::CreateR1(values)); +} + +template +XlaOp XlaBuilder::ConstantR2( + std::initializer_list> values) { + return ConstantLiteral(*Literal::CreateR2(values)); +} + +template +XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, + const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantFromArray(const Array& values) { + return ConstantLiteral(*Literal::CreateFromArray(values)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { + return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D& values) { + return ConstantFromArray(values); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + +template +XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { + return ConstantFromArray(values); +} + +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +// +// TODO(b/74197823): This is a part of a NOT YET ready refactor. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce984564d016ce65fa6c932f3cda290cc0d75a4a --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -0,0 +1,237 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +#include + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +namespace { + +namespace op = xla::testing::opcode_matchers; + +using ::testing::HasSubstr; + +// TODO(b/74197823): Move the tests to service/. +class XlaBuilderTest : public ::testing::Test { + protected: + StatusOr> BuildHloModule(XlaBuilder* b) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, b->Build()); + const HloModuleProto& proto = computation.proto(); + TF_ASSIGN_OR_RETURN(const auto& config, + HloModule::CreateModuleConfigFromProto( + proto, legacy_flags::GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(proto, config); + } + + // Returns the name of the test currently being run. + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } +}; + +TEST_F(XlaBuilderTest, OnePlusTwo) { + XlaBuilder b(TestName()); + b.Add(b.ConstantR0(1.0), b.ConstantR0(2.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Constant(), op::Constant())); +} + +TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x"); + b.Add(x, b.ConstantR0(1.0)); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant()))); +} + +TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { + XlaBuilder b(TestName()); + const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6}); + const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4}); + auto x = b.Parameter(0, x_shape, "x"); + auto y = b.Parameter(1, y_shape, "y"); + auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1)))); +} + +TEST_F(XlaBuilderTest, XPlusX) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x"); + b.Add(x, x); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0))); +} + +TEST_F(XlaBuilderTest, ShapeInferenceError) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference")); +} + +TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) { + XlaBuilder b_call("add"); + b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Add(x, y); + auto statusor = BuildHloModule(&b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("parameter 0 already registered")); +} + +TEST_F(XlaBuilderTest, Call) { + XlaBuilder b_call("the_only_to_apply"); + auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1"); + b_call.Add(p0, p1); + TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build()); + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto one = b.ConstantR0(1); + auto two = b.ConstantR0(2); + b.Add(b.Call(call, {x, y}), b.Call(call, {one, two})); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()), + op::Call(op::Constant(), op::Constant()))); +} + +TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y"); + b.Add(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // Expected: + // + // x: f32[1,2,3] y: f32[1,2,1] + // | | + // | reshape: f32[1,2] + // | | + // | broadcast: f32[1,2,3] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Parameter(0), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y"); + b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + // The binary operation has in-dim broadcast and degenerate broadcast, should + // first do the in-dim broadcast then convert the degnerate broadcast into a + // reshape and a broadcast. + // + // Expected: + // + // x: f32[2,3] y: f32[2,1,4] + // | | + // broadcast: f32[2,3,4] reshape: f32[2,4] + // | | + // | broadcast: f32[2,3,4] + // \ / + // add + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Add(op::Broadcast(op::Parameter(0)), + op::Broadcast(op::Reshape(op::Parameter(1))))); +} + +TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { + XlaBuilder b1("b1"); + auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0"); + XlaBuilder builder("main"); + builder.Add(p0, p0); + auto statusor = builder.Build(); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Do not add XlaOp from builder b1 to builder main")); +} + +TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Parameter())); +} + +TEST_F(XlaBuilderTest, ReshapeHasTranspose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x"); + b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter()))); +} + +TEST_F(XlaBuilderTest, Transpose) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + b.Transpose(x, /*permutation=*/{1, 0}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Transpose(op::Parameter())); +} + +// TODO(b/65209188): Create a dedicated lowering for Xor. +TEST_F(XlaBuilderTest, Xor) { + XlaBuilder b(TestName()); + auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x"); + auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y"); + b.Xor(x, y); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + LOG(ERROR) << module->ToString(); + EXPECT_THAT(root, + op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)), + op::And(op::Parameter(0), op::Not(op::Parameter(1))))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6752c601026518825c7994f6b6fa20d20f34f24 --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" + +#include + +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +StatusOr XlaComputation::GetProgramShape() const { + TF_RET_CHECK(proto_.has_program_shape()); + return proto_.program_shape(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h new file mode 100644 index 0000000000000000000000000000000000000000..7ad212aa24cd32d104cc4db7aa164c22c9f5be8f --- /dev/null +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +// +// TODO(b/74197823): Replace xla::Computation with this one. +class XlaComputation { + public: + XlaComputation() : unique_id_(-1) {} + XlaComputation(const HloModuleProto& proto) + : unique_id_(proto.id()), proto_(proto) {} + + ~XlaComputation() {} + + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + StatusOr GetProgramShape() const; + + const HloModuleProto& proto() const { return proto_; } + + // Returns true if this object is a null Computation. + bool IsNull() const { return unique_id_ == -1; } + + private: + XlaComputation(const int64 unique_id) : unique_id_(unique_id) {} + HloModuleProto* mutable_proto() { return &proto_; } + friend class XlaBuilder; + + int64 unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 392ad9010ab81923a089c7b00a79ddc281af92bb..1700c977189a9e4aedf6a6a75923c13678dae667 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -87,4 +87,11 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const { return device_assignment_; } +ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) { + rng_seed_ = rng_seed; + return *this; +} + +int ExecutableRunOptions::rng_seed() const { return rng_seed_; } + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index d4fcbf0493c936ebcd0639a432e56b62ee15672c..2c1d9ffff10ed26410898ad258aa6b5b2cd37518 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -84,6 +84,9 @@ class ExecutableRunOptions { DeviceAssignment* device_assignment); const DeviceAssignment* device_assignment() const; + ExecutableRunOptions& set_rng_seed(int rng_seed); + int rng_seed() const; + private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; @@ -92,6 +95,7 @@ class ExecutableRunOptions { tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; + int rng_seed_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 0a9725db0a4fcf963cadcacf2cbc1d95d2c7239d..89353448e29ec3d97275dac288e23aa8e96e31b2 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -75,17 +75,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index c8ed3e3a2b009ddffdfb79a9a6ced8d5e736bee6..70ae95bf47398589e3c20f72c1f2084a738f253a 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -40,7 +40,10 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { flags->set_xla_cpu_multi_thread_eigen(true); flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); flags->set_xla_eliminate_hlo_implicit_broadcast(true); - +#ifdef INTEL_MKL + flags->set_xla_cpu_use_mkl_dnn(true); +#endif // INTEL_MKL + flags->set_xla_gpu_max_kernel_unroll_factor(1); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. flags->set_xla_gpu_use_cudnn_batchnorm(false); @@ -220,6 +223,11 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), + tensorflow::Flag( + "xla_gpu_max_kernel_unroll_factor", + int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), + flag_values->xla_gpu_max_kernel_unroll_factor(), + "Specify the maximum kernel unroll factor for the GPU backend."), tensorflow::Flag( "xla_dump_optimized_hlo_proto_to", flag_values->mutable_xla_dump_optimized_hlo_proto_to(), @@ -288,6 +296,10 @@ void AllocateFlags() { flag_values->xla_gpu_use_cudnn_batchnorm(), "Allows the GPU backend to implement batchnorm HLOs using cudnn, " "rather than expanding them to a soup of HLOs."), + tensorflow::Flag("xla_cpu_use_mkl_dnn", + bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), + flag_values->xla_cpu_use_mkl_dnn(), + "Generate calls to MKL-DNN in the CPU backend."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index a3b4286f4c12bf39a44c63dd6e7d303a46a418c3..7b6ae311c1099dccb8dceb2f49743c1b185cd5ab 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e0a9b148b443e90a0c4f3e19660b6234d49eef84..c315b4ff30059147ee33dcdd5b0858a1c39e5999 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -97,11 +97,18 @@ Literal::Literal(const Shape& shape, bool allocate_arrays) const Shape& subshape = piece.subshape(); if (ShapeUtil::IsArray(subshape)) { if (allocate_arrays) { - piece.set_buffer(new char[piece.size_bytes()]); if (LayoutUtil::IsSparseArray(subshape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. + const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(subshape.layout()); + piece.set_buffer( + new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( + subshape.element_type())]); piece.set_sparse_indices(new SparseIndexArray( - LayoutUtil::MaxSparseElements(subshape.layout()), - ShapeUtil::Rank(subshape))); + max_sparse_elements, ShapeUtil::Rank(subshape))); + } else { + piece.set_buffer(new char[piece.size_bytes()]); } } else { piece.set_buffer(nullptr); @@ -223,7 +230,7 @@ Status Literal::CopySliceFromInternal( Literal::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](const std::vector& indexes) { + auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -248,6 +255,28 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } +Status Literal::CopyElementFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index) { + DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); + const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( + src_literal.shape(), src_index); + const int64 dest_linear_index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + char* dest_address = + static_cast(untyped_data()) + dest_linear_index * primitive_size; + const char* source_address = + static_cast(src_literal.untyped_data()) + + src_linear_index * primitive_size; + if (dest_address != source_address) { + memcpy(dest_address, source_address, primitive_size); + } + return Status::OK(); +} + std::vector Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector elements; @@ -343,7 +372,7 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { #undef COPY_ELEMENTS default: return Unimplemented( - "Unhandled primitive type %s", + "Copying a Literal object with element type %s is not implemented.", PrimitiveType_Name(subshape().element_type()).c_str()); } } @@ -491,7 +520,10 @@ Status Literal::CopySliceFrom(const Literal& src_literal, default: break; } - return Unimplemented("Unhandled primitive type %d", shape().element_type()); + return Unimplemented( + "Copying a slice from a Literal object with element type %d is not " + "implemented.", + shape().element_type()); } /* static */ Literal Literal::Zero(PrimitiveType primitive_type) { @@ -808,9 +840,10 @@ std::unique_ptr Literal::Slice( DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { CHECK_GE(start_indices[dnum], 0); - CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)); + CHECK_LE(limit_indices[dnum], shape().dimensions(dnum)) + << "dnum = " << dnum; int64 dimension = limit_indices[dnum] - start_indices[dnum]; - CHECK_GT(dimension, 0); + CHECK_GE(dimension, 0) << "dnum = " << dnum; result_dimensions.push_back(dimension); } const auto result_shape = @@ -903,7 +936,7 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, case U64: return StrCat(Get(multi_index, shape_index)); case F16: - return StrCat(Get(multi_index, shape_index)); + return StrCat(static_cast(Get(multi_index, shape_index))); case F32: return StrCat(Get(multi_index, shape_index)); case BF16: @@ -953,7 +986,8 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, return StrCat( GetSparseElement(sparse_element_number, shape_index)); case F16: - return StrCat(GetSparseElement(sparse_element_number, shape_index)); + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); case F32: return StrCat( GetSparseElement(sparse_element_number, shape_index)); @@ -997,6 +1031,36 @@ StatusOr Literal::GetIntegralAsS64( } } +Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value) { + CHECK(LayoutUtil::IsDenseArray(shape())); + switch (shape().element_type()) { + case PRED: + Set(multi_index, value); + break; + case U8: + Set(multi_index, value); + break; + case S32: + Set(multi_index, value); + break; + case S64: + Set(multi_index, value); + break; + case U32: + Set(multi_index, value); + break; + case U64: + Set(multi_index, value); + break; + default: + return FailedPrecondition( + "Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type()).c_str()); + } + return Status::OK(); +} + tensorflow::gtl::ArraySlice Literal::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); @@ -1009,6 +1073,49 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } +Literal Literal::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape_)); + CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); + switch (shape_.element_type()) { + case PRED: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 8 bit types. + case S8: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U8: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 16 bit types. + case BF16: + return std::move( + *Literal::CreateR0(GetFirstElement())); + case F16: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S16: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U16: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 32 bit types. + case F32: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S32: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U32: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 64 bit types. + case C64: + return std::move( + *Literal::CreateR0(GetFirstElement())); + case F64: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S64: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U64: + return std::move(*Literal::CreateR0(GetFirstElement())); + default: + LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + } +} + void Literal::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: @@ -1285,8 +1392,9 @@ void Literal::EachCellAsString( } namespace { -template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +template +std::unique_ptr ConvertBetweenNativeTypesWithConverter( + const Literal& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1296,11 +1404,40 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); + dest_data[i] = converter(src_data[i]); } return result_literal; } +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + auto converter = [](NativeSrcT src) { return static_cast(src); }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +template +typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const Literal& src_literal) { + auto converter = [](NativeSrcT src) { + return tensorflow::bit_cast(src); + }; + return ConvertBetweenNativeTypesWithConverter( + src_literal, converter); +} + +// This template specialization is here to make the compiler happy. bit_cast has +// a static check that the types are the same size. This specialization should +// never be used because the source and destination types are checked for +// identical sizes higher up. +template +typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), + std::unique_ptr>::type +BitcastBetweenNativeTypes(const Literal& src_literal) { + LOG(FATAL) << "Invalid bitcast between types of different sizes."; +} + template std::unique_ptr ConvertToC64(const Literal& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); @@ -1320,21 +1457,33 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { +std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, + bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return ConvertBetweenNativeTypes< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative< - primitive_dest_type>::type>(src_literal); + if (bitcast) { + return BitcastBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } else { + return ConvertBetweenNativeTypes< + typename primitive_util::PrimitiveTypeToNative< + primitive_src_type>::type, + typename primitive_util::PrimitiveTypeToNative< + primitive_dest_type>::type>(src_literal); + } } template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { + const Literal& src_literal, PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { -#define CONVERT_IF_TYPES_MATCH(type) \ - case (type): \ - return ConvertIfTypesMatch(src_literal); +#define CONVERT_IF_TYPES_MATCH(type) \ + case (type): \ + return ConvertIfTypesMatch(src_literal, \ + bitcast); CONVERT_IF_TYPES_MATCH(PRED) CONVERT_IF_TYPES_MATCH(S8) CONVERT_IF_TYPES_MATCH(S32) @@ -1348,25 +1497,31 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: - return ConvertToC64(src_literal); + if (!bitcast) { + return ConvertToC64(src_literal); + } + break; // Other types are not yet supported. default: - return InvalidArgument( - "Unimplemented: Convert from type %s to type %s", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + break; } + return Unimplemented( + "Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } -} // namespace - -StatusOr> Literal::Convert( - PrimitiveType primitive_dest_type) const { - TF_RET_CHECK(ShapeUtil::IsArray(shape())); - switch (shape().element_type()) { -#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ - case (type): \ - return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type); +StatusOr> ConvertSwitch( + const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); + if (literal.shape().element_type() == primitive_dest_type) { + return literal.CloneToUnique(); + } + switch (literal.shape().element_type()) { +#define CONVERT_IF_DEST_TYPE_MATCHES(type) \ + case (type): \ + return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \ + bitcast); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) CONVERT_IF_DEST_TYPE_MATCHES(S32) @@ -1381,12 +1536,62 @@ StatusOr> Literal::Convert( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return InvalidArgument("Unimplemented: Convert from type %s to type %s", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented( + "%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } +} // namespace + +StatusOr> Literal::Convert( + PrimitiveType primitive_dest_type) const { + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); +} + +StatusOr> Literal::BitcastConvert( + PrimitiveType primitive_dest_type) const { + if (primitive_util::BitWidth(shape().element_type()) != + primitive_util::BitWidth(primitive_dest_type)) { + return InvalidArgument( + "Cannot bitcast convert from %s to %s, bit widths are different: %d != " + "%d", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str(), + primitive_util::BitWidth(shape().element_type()), + primitive_util::BitWidth(primitive_dest_type)); + } + return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); +} + +StatusOr> Literal::ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16) const { + if (!ShapeUtil::IsTuple(dest_shape)) { + if (round_f32_to_bf16 && shape().element_type() == F32 && + dest_shape.element_type() == BF16) { + auto converter = [](float src) { + return tensorflow::bfloat16::round_to_bfloat16(src); + }; + return ConvertBetweenNativeTypesWithConverter(*this, + converter); + } + return Convert(dest_shape.element_type()); + } + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + auto element = LiteralView::Create(*this, {i}); + TF_ASSIGN_OR_RETURN( + auto new_element, + element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); + elements.push_back(std::move(*new_element)); + } + auto converted = MakeUnique(); + *converted = Literal::MoveIntoTuple(&elements); + return std::move(converted); +} + template bool Literal::Piece::EqualElementsInternal( const Literal::Piece& other, std::vector* multi_index) const { @@ -1571,6 +1776,92 @@ bool Literal::IsAllComplex(complex64 value) const { } } +bool Literal::IsAllFirst() const { + for (const auto& pair : pieces_) { + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::HasZeroElements(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + } + return true; +} + bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index d996004888ab521790b4c5a10da2a93f0d98d12f..8aa19222dc4b9175ec72128dfdad448f65c23e91 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -262,6 +262,11 @@ class Literal { tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); + // Copies one element from src_literal[src_index] to (*this)[dest_index]. + Status CopyElementFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_index, + tensorflow::gtl::ArraySlice dest_index); + // Returns a vector containing the tuple elements of this Literal as separate // Literals. This Literal must be tuple-shaped and can be a nested tuple. The // elements are moved into the new Literals; no data is copied. Upon return @@ -328,11 +333,30 @@ class Literal { template std::unique_ptr Replicate(int64 times) const; - // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. This literal must be array-shaped. + // Converts this literal to another primitive type using + // static_cast<>. Returns an error if the conversion is not possible. This + // literal must be array-shaped. StatusOr> Convert( PrimitiveType primitive_dest_type) const; + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -451,6 +475,9 @@ class Literal { template NativeT GetFirstElement() const; + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + // As Get(), but determines the correct type and converts the value // into text. string GetAsString(tensorflow::gtl::ArraySlice multi_index, @@ -466,6 +493,11 @@ class Literal { StatusOr GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const; + // As Set(), but truncates `value` to the literal element type before storing. + // This literal must be an array. + Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, + int64 value); + // Returns an identity matrix (rank 2) with the given row and column count. template static std::unique_ptr MakeIdentityR2(int64 size); @@ -563,6 +595,12 @@ class Literal { template Status Populate(const FnType& generator); + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); + // Fills this literal with the given value. template void PopulateWithValue(NativeT value); @@ -602,6 +640,9 @@ class Literal { // This literal must have a dense layout. bool IsAllComplex(complex64 value) const; + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice indices) const; @@ -700,7 +741,13 @@ class Literal { int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } // Returns the number of elements in this piece's array. - int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); } + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } // Copy the data from 'src' into this piece's buffer. Shapes of this piece // and src must be compatible. @@ -758,6 +805,10 @@ class Literal { // buffer). void DeallocateBuffers(); + // Implementation details shared between Populate() and PopulateParallel() + template + Status PopulateInternal(const FnType& generator, bool parallel); + Shape shape_; ShapeTree pieces_; @@ -808,8 +859,7 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const { << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), - ShapeUtil::ElementsIn(subshape())); + reinterpret_cast(buffer()), element_count()); } template @@ -822,7 +872,7 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() { << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), ShapeUtil::ElementsIn(subshape())); + reinterpret_cast(buffer()), element_count()); } template @@ -1237,19 +1287,20 @@ void Literal::PopulateSparse(SparseIndexArray indices, CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - root_data.remove_suffix(max_elements - values.size()); + // Piece::data() returns an ArraySlice of size equal to the number of indices + // in the SparseIndexArray. So there is no need to adjust the size of the data + // here. It is enough to just copy the incoming values into the data buffer. std::copy(values.begin(), values.end(), root_data.begin()); *this->root_piece().sparse_indices() = std::move(indices); if (sort) { auto root_data = this->root_piece().data(); - root_data.remove_suffix(root_data.size() - num_elements); this->root_piece().sparse_indices()->SortWithValues(root_data); } DCHECK(this->root_piece().sparse_indices()->Validate(shape())); } template -Status Literal::Populate(const FnType& generator) { +Status Literal::PopulateInternal(const FnType& generator, bool parallel) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); @@ -1259,11 +1310,11 @@ Status Literal::Populate(const FnType& generator) { if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); - DimensionVector minor_scan_indexes(rank, 0); int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](const std::vector& indexes) { + auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + DimensionVector minor_scan_indexes(rank, 0); const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); @@ -1271,17 +1322,35 @@ Status Literal::Populate(const FnType& generator) { minor_scan_indexes[stride_config.minor_dimension] = i; literal_data.at(index + i) = generator(minor_scan_indexes); } - return true; }; - ShapeUtil::ForEachIndex(this_shape, stride_config.base, - stride_config.dimensions, stride_config.step, - init_function); + if (parallel) { + ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base, + stride_config.dimensions, + stride_config.step, init_function); + } else { + ShapeUtil::ForEachIndex( + this_shape, stride_config.base, stride_config.dimensions, + stride_config.step, + [&init_function](tensorflow::gtl::ArraySlice indexes) { + init_function(indexes); + return true; + }); + } } else { // For scalars. literal_data.at(0) = generator({}); } return Status::OK(); } +template +Status Literal::Populate(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/false); +} + +template +Status Literal::PopulateParallel(const FnType& generator) { + return PopulateInternal(generator, /*parallel=*/true); +} template void Literal::PopulateWithValue(NativeT value) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index b3583c2eb75de8297d5e7507430491f119bd4462..61046784e05623cd3117c24ecc6d6c474739bbd5 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -30,6 +31,7 @@ limitations under the License. namespace xla { namespace { +using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -214,11 +216,9 @@ TEST_F(LiteralUtilTest, CreateSparse) { std::vector expected_values = {8, 9, 7, 10}; EXPECT_EQ(literal->sparse_indices()->data(), - tensorflow::gtl::ArraySlice( - expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), - expected_values.size()), - tensorflow::gtl::ArraySlice(expected_values)); + ArraySlice(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal->data(), ArraySlice(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -290,7 +290,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format on std::vector> seen; literal->EachCellAsString( - [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { + [&seen](ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -501,6 +501,24 @@ TEST_F(LiteralUtilTest, IsAllComplex) { ->IsAllComplex({8.0f, 9.0f})); } +TEST_F(LiteralUtilTest, IsAllFirst) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR1({false, true})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({false, false})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0(0.0f); auto scalar_one = Literal::CreateR0(1.0f); @@ -604,11 +622,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { // clang-format on auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, original->Get( - {indices[2], indices[3], indices[0], indices[1]})); - }); + reshape->EachCell([&](ArraySlice indices, float value) { + EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); + }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { @@ -845,7 +862,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](const std::vector& indexes) { + auto init_proc = [&](ArraySlice indexes) { source->Set(indexes, ++seqnr); return true; }; @@ -861,7 +878,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](const std::vector& indexes) { + auto check_proc = [&](ArraySlice indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); @@ -1049,7 +1066,7 @@ TEST_F(LiteralUtilTest, Populate) { primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); auto literal = Literal::CreateFromShape(shape); - auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { + auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1061,7 +1078,49 @@ TEST_F(LiteralUtilTest, Populate) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](const std::vector& indexes) { + auto check_function = [&](ArraySlice indexes) { + auto value = literal->Get(indexes); + matched = matched && (value == generator(indexes)); + return matched; + }; + ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + check_function); + EXPECT_TRUE(matched); + } +} + +TEST_F(LiteralUtilTest, PopulateParallel) { + struct PopulateData { + std::vector dimensions; + std::vector layout; + } populate_data[] = { + {{}, {}}, + {{0}, {0}}, + {{16}, {0}}, + {{2, 0}, {1, 0}}, + {{4, 16}, {1, 0}}, + {{21, 12}, {0, 1}}, + {{6, 11, 17}, {2, 0, 1}}, + {{6, 11, 5, 17}, {3, 2, 0, 1}}, + }; + for (const auto& data : populate_data) { + Shape shape = ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), data.dimensions, + data.layout); + auto literal = Literal::CreateFromShape(shape); + auto generator = [&](ArraySlice indexes) -> uint32 { + // Offsets from linear index just to avoid R0 literals to be initialized + // with zero. + return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + indexes) + + 17; + }; + TF_EXPECT_OK(literal->PopulateParallel(generator)); + + std::vector zero_base(data.dimensions.size(), 0); + std::vector step(data.dimensions.size(), 1); + bool matched = true; + auto check_function = [&](ArraySlice indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; @@ -1214,15 +1273,34 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { EXPECT_EQ(*conv, *c64); EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); EXPECT_EQ(c64->Convert(S32).status().code(), - tensorflow::error::INVALID_ARGUMENT); + tensorflow::error::UNIMPLEMENTED); +} + +TEST_F(LiteralUtilTest, BitcastConvert) { + auto original = + Literal::CreateR1({tensorflow::bit_cast(2.5f), + tensorflow::bit_cast(-42.25f), + tensorflow::bit_cast(100.f), 0xbeef}); + auto expected = Literal::CreateR1( + {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, + original->BitcastConvert(F32)); +} + +TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { + auto literal = Literal::CreateR0(1234); + Status status = literal->BitcastConvert(F64).status(); + EXPECT_NE(Status::OK(), status); + EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), + "bit widths are different")); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1684,7 +1762,7 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(half{2.0})); + tensorflow::strings::StrCat(static_cast(half{2.0}))); ASSERT_EQ( Literal::CreateSparse( dimensions, indices, diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index e2972f06016ab3555c4fc0cc4616993fe6764b1e..0517a5502e686def4ffea59f929aef225186a8aa 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -72,15 +72,3 @@ tf_py_wrap_cc( "//tensorflow/compiler/xla/service:cpu_plugin", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cb7bb21e092c80d7360c23f3d6b00409a75dce23..2bacc6a9142971f6d14b3929fb1a69e2a40052e2 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -368,6 +368,12 @@ ComputationDataHandle LocalComputationBuilder::Slice( return builder_.Slice(operand, start_indices, limit_indices, strides); } +ComputationDataHandle LocalComputationBuilder::SliceInDim( + const ComputationDataHandle& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { + return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +} + ComputationDataHandle LocalComputationBuilder::DynamicSlice( const ComputationDataHandle& operand, const ComputationDataHandle& start_indices, @@ -515,6 +521,17 @@ ComputationDataHandle LocalComputationBuilder::Conditional( false_computation.computation()); } +StatusOr LocalComputationBuilder::IsConstant( + const ComputationDataHandle& operand, int64 num_parameters) { + return builder_.IsConstant(operand, num_parameters); +} + +StatusOr> LocalComputationBuilder::ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { + return builder_.ComputeConstant(operand, output_layout, parameters); +} + #define _FORWARD(method_name, return_sig, args_sig, args) \ return_sig LocalComputationBuilder::method_name args_sig { \ return builder_.method_name args; \ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index d3e9503ea10b011520ec5148a756ef4d421f244c..31046e60f11af9cc89ddec4c5fd16babfc8eb231 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -170,6 +170,10 @@ class LocalComputationBuilder { tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides); + ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + ComputationDataHandle DynamicSlice( const ComputationDataHandle& operand, const ComputationDataHandle& start_indices, @@ -264,6 +268,13 @@ class LocalComputationBuilder { const ComputationDataHandle& false_operand, const LocalComputation& false_computation); + StatusOr IsConstant(const ComputationDataHandle& operand, + int64 num_parameters); + + StatusOr > ComputeConstant( + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters); + #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 456e341f877e529f7fc5ebc81d85862bfa291943..ac792e8189bda9eda472e7d282db86ac988c57b9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -141,6 +141,33 @@ bool GetIntAttr(PyObject* o, const char* field, int64* result) { return true; } +// Returns "ok"; true if there is no error, false if there was an error. +bool HandleStringAttribute(PyObject* o, + const char* attr_name, + std::function f) { + if (!PyObject_HasAttrString(o, attr_name)) { + return true; // It's ok for the object to not have the attribute. + } + PyObject* attr = PyObject_GetAttrString(o, attr_name); + if (attr == nullptr) { + return false; // An error occurred getting the attribute. + } + if (attr == Py_None) { + Py_DECREF(attr); + return true; // The attribute is None, which we consider ok. + } + if (!PyString_Check(attr)) { + string message = tensorflow::strings::Printf("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr).c_str()); + PyErr_SetString(PyExc_TypeError, message.c_str()); + Py_DECREF(attr); + return false; // Type error, not ok. + } + f(PyString_AsString(attr)); + Py_DECREF(attr); + return true; // Handled string attribute, ok! +} + } } %} @@ -155,7 +182,7 @@ tensorflow::ImportNumpy(); %typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { const int64 handle = numpy::PyIntOrPyLongToLong($input); if (handle == -1 && PyErr_Occurred()) { - return NULL; + SWIG_fail; } temp.set_handle(handle); $1 = &temp; @@ -174,7 +201,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -184,7 +211,7 @@ tensorflow::ImportNumpy(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -197,7 +224,7 @@ tensorflow::ImportNumpy(); } } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } } @@ -206,7 +233,16 @@ tensorflow::ImportNumpy(); $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = PyBool_FromLong($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; } } @@ -214,8 +250,9 @@ tensorflow::ImportNumpy(); if (!$1.ok()) { PyErr_SetString( PyExc_RuntimeError, $1.ToString().c_str()); - return NULL; + SWIG_fail; } + Py_INCREF(Py_None); $result = Py_None; } @@ -225,7 +262,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -237,13 +274,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "Argument sequence element cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i] = numpy::PyIntOrPyLongToLong(py_int); if (temps[i] == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } Py_DECREF(py_int); Py_DECREF(o); @@ -257,7 +294,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.resize(size); @@ -268,13 +305,13 @@ tensorflow::ImportNumpy(); PyErr_SetString( PyExc_TypeError, "Argument sequence element cannot be converted to int"); - return NULL; + SWIG_fail; } const int64 handle = numpy::PyIntOrPyLongToLong(py_int); if (handle == -1 && PyErr_Occurred()) { Py_DECREF(py_int); Py_DECREF(o); - return NULL; + SWIG_fail; } temps[i].set_handle(handle); Py_DECREF(py_int); @@ -289,7 +326,7 @@ tensorflow::ImportNumpy(); (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); @@ -298,7 +335,7 @@ tensorflow::ImportNumpy(); LocalShapedBuffer* lsbp; if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), SWIG_POINTER_EXCEPTION)) == -1) { - return NULL; + SWIG_fail; } temps.push_back(lsbp); Py_DECREF(o); @@ -312,7 +349,7 @@ tensorflow::ImportNumpy(); literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); - return NULL; + SWIG_fail; } $1 = literal_status.ValueOrDie().get(); } @@ -324,7 +361,7 @@ tensorflow::ImportNumpy(); %typemap(out) StatusOr< std::unique_ptr > { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - return NULL; + SWIG_fail; } $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); } @@ -332,7 +369,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector& (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -341,7 +378,7 @@ tensorflow::ImportNumpy(); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); Py_DECREF(o); @@ -355,7 +392,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::OpMetadataFromPyObject($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -367,7 +404,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -382,7 +419,7 @@ tensorflow::ImportNumpy(); StatusOr statusor = numpy::XlaShapeFromPyShape($input); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temp = std::move(statusor).ValueOrDie(); $1 = &temp; @@ -396,7 +433,7 @@ tensorflow::ImportNumpy(); %typemap(in) const std::vector& (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -405,7 +442,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -416,7 +453,7 @@ tensorflow::ImportNumpy(); std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { @@ -428,7 +465,7 @@ tensorflow::ImportNumpy(); Py_DECREF(o); if (!statusor.ok()) { PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); - return NULL; + SWIG_fail; } temps.push_back(statusor.ConsumeValueOrDie()); } @@ -442,18 +479,18 @@ tensorflow::ImportNumpy(); PyObject* py_int = numpy::PyNumberToPyInt($input); if (!py_int) { PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); - return NULL; + SWIG_fail; } const long value = numpy::PyIntOrPyLongToLong(py_int); if (value == -1 && PyErr_Occurred()) { Py_DECREF(py_int); - return NULL; + SWIG_fail; } if (!PrimitiveType_IsValid(value)) { PyErr_SetString( PyExc_TypeError, "Argument not valid for PrimitiveType enum"); Py_DECREF(py_int); - return NULL; + SWIG_fail; } $1 = static_cast(value); } @@ -464,19 +501,19 @@ tensorflow::ImportNumpy(); (std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); - return NULL; + SWIG_fail; } const int size = PySequence_Size($input); temps.reserve(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); if (!o) { - return NULL; + SWIG_fail; } PyObject* first = PyTuple_GetItem(o, 0); if (!first) { Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* first_pyint = numpy::PyNumberToPyInt(first); if (!first_pyint) { @@ -484,13 +521,13 @@ tensorflow::ImportNumpy(); PyExc_TypeError, "First pair item cannot be converted to int"); Py_DECREF(o); - return NULL; + SWIG_fail; } PyObject* second = PyTuple_GetItem(o, 1); if (!second) { Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } PyObject* second_pyint = numpy::PyNumberToPyInt(second); if (!second_pyint) { @@ -499,21 +536,21 @@ tensorflow::ImportNumpy(); "Second pair item cannot be converted to int"); Py_DECREF(o); Py_DECREF(first_pyint); - return NULL; + SWIG_fail; } const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); if (first_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); if (second_value == -1 && PyErr_Occurred()) { Py_DECREF(o); Py_DECREF(first_pyint); Py_DECREF(second_pyint); - return NULL; + SWIG_fail; } temps.push_back(std::make_pair(first_value, second_value)); Py_DECREF(o); @@ -531,26 +568,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( $input, "lhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_contracting_dimensions); if (length == -1) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); if (!item) { Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -561,26 +598,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( $input, "rhs_contracting_dimensions"); if (!lhs_contracting_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_contracting_dimensions); if (length == -1) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); if (!item) { Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_contracting_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_contracting_dimensions(dimension); Py_DECREF(item); @@ -591,26 +628,26 @@ tensorflow::ImportNumpy(); PyObject* lhs_batch_dimensions = PyObject_GetAttrString( $input, "lhs_batch_dimensions"); if (!lhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(lhs_batch_dimensions); if (length == -1) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); if (!item) { Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(lhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_lhs_batch_dimensions(dimension); Py_DECREF(item); @@ -621,26 +658,26 @@ tensorflow::ImportNumpy(); PyObject* rhs_batch_dimensions = PyObject_GetAttrString( $input, "rhs_batch_dimensions"); if (!rhs_batch_dimensions) { - return NULL; + SWIG_fail; } length = PySequence_Size(rhs_batch_dimensions); if (length == -1) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); if (!item) { Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(rhs_batch_dimensions); - return NULL; + SWIG_fail; } dimension_numbers.add_rhs_batch_dimensions(dimension); Py_DECREF(item); @@ -656,20 +693,20 @@ tensorflow::ImportNumpy(); (PaddingConfig padding_config) { PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); if (!dimensions) { - return NULL; + SWIG_fail; } int length = PySequence_Size(dimensions); if (length == -1) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(dimensions, i); if (!item) { Py_DECREF(dimensions); - return NULL; + SWIG_fail; } int64 edge_padding_low, edge_padding_high, interior_padding; if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) @@ -677,7 +714,7 @@ tensorflow::ImportNumpy(); || !GetIntAttr(item, "interior_padding", &interior_padding)) { Py_DECREF(item); Py_DECREF(dimensions); - return NULL; + SWIG_fail; } Py_DECREF(item); @@ -699,32 +736,32 @@ tensorflow::ImportNumpy(); int64 value; if (!GetIntAttr($input, "input_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_batch_dimension(value); if (!GetIntAttr($input, "input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_input_feature_dimension(value); if (!GetIntAttr($input, "output_batch_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_batch_dimension(value); if (!GetIntAttr($input, "output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_output_feature_dimension(value); if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { - return NULL; + SWIG_fail; } dimension_numbers.set_kernel_input_feature_dimension(value); @@ -733,24 +770,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "input_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_input_spatial_dimensions(dimension); Py_DECREF(item); @@ -759,24 +796,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_kernel_spatial_dimensions(dimension); Py_DECREF(item); @@ -785,24 +822,24 @@ tensorflow::ImportNumpy(); o = PyObject_GetAttrString($input, "output_spatial_dimensions"); if (!o) { - return NULL; + SWIG_fail; } length = PySequence_Size(o); if (length == -1) { Py_DECREF(o); - return NULL; + SWIG_fail; } for (int i = 0; i < length; ++i) { PyObject* item = PySequence_GetItem(o, i); if (!item) { Py_DECREF(o); - return NULL; + SWIG_fail; } const int64 dimension = numpy::PyIntOrPyLongToLong(item); if (dimension == -1 && PyErr_Occurred()) { Py_DECREF(item); Py_DECREF(o); - return NULL; + SWIG_fail; } dimension_numbers.add_output_spatial_dimensions(dimension); Py_DECREF(item); @@ -819,16 +856,32 @@ tensorflow::ImportNumpy(); if ($input == Py_None) { $1 = NULL; } else { - PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); - if (!o) { - return NULL; + if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) { + build_options.set_generate_hlo_graph(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) { + build_options.set_dump_optimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { + build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } + + PyObject* o = PyObject_GetAttrString($input, "hlo_profile"); + if (o == NULL) { + SWIG_fail; } if (o != Py_None) { - if (!PyString_Check(o)) { - PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); - return NULL; + if (!PyBool_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None."); + SWIG_fail; } - build_options.set_generate_hlo_graph(PyString_AsString(o)); + build_options.set_hlo_profile(o == Py_True); } Py_DECREF(o); @@ -841,7 +894,7 @@ tensorflow::ImportNumpy(); if (!statusor.ok()) { PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); - return NULL; + SWIG_fail; } build_options.set_result_layout(statusor.ValueOrDie()); } @@ -886,6 +939,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Collapse; %unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; %unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::SliceInDim; %unignore xla::swig::LocalComputationBuilder::DynamicSlice; %unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; %unignore xla::swig::LocalComputationBuilder::ConcatInDim; @@ -906,6 +960,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::RngBernoulli; %unignore xla::swig::LocalComputationBuilder::While; %unignore xla::swig::LocalComputationBuilder::Conditional; +%unignore xla::swig::LocalComputationBuilder::IsConstant; %unignore xla::swig::LocalComputationBuilder::Eq; %unignore xla::swig::LocalComputationBuilder::Ne; %unignore xla::swig::LocalComputationBuilder::Ge; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 3d87480728aab1d4ebbc71c6c7504d37cae5edaf..eec48479c929ab0823fef342fc284bfdc4b1f339 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -170,8 +170,7 @@ static string PyObjectCppStr(PyObject* o) { return ExtractStringAndDecref(s); } -// Safely returns a repr of the given Python object o as a C++ string. -static string PyObjectCppRepr(PyObject* o) { +string PyObjectCppRepr(PyObject* o) { PyObject* r = PyObject_Repr(o); return ExtractStringAndDecref(r); } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index adfcc3b8588dce01718bb19dea936bace483be4d..9656cb1c31c39dbe54293700c2765d0723255657 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -107,6 +107,9 @@ void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { std::copy(source.begin(), source.end(), dest); } +// Safely returns a repr of the given Python object o as a C++ string. +string PyObjectCppRepr(PyObject* o); + // Workarounds for Python 2 and 3 interop PyObject* LongToPyIntOrPyLong(long x); // NOLINT diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 9bda9d09294bc75acaa35d8e4a512820046e8920..9c81f6439d0d9f0a0f0d1d3402e9c1ada46e8691 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -30,9 +30,9 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api -# Most functions are snake_case for consistency with other modules, -# whereas method names of ComputationBuilder and LocalComputation are -# CamelCase for consistency with XLA. +# Most functions are snake_case for consistency with other modules, whereas +# method names of ComputationBuilder and LocalComputation are CamelCase for +# consistency with XLA. # pylint: disable=invalid-name @@ -123,24 +123,34 @@ _BINARY_OPS = [ 'Pow', ] + XLA_ELEMENT_TYPE_TO_DTYPE = { - xla_data_pb2.F32: np.dtype(np.float32), - xla_data_pb2.F64: np.dtype(np.float64), - xla_data_pb2.S32: np.dtype(np.int32), - xla_data_pb2.S64: np.dtype(np.int64), - xla_data_pb2.U32: np.dtype(np.uint32), - xla_data_pb2.U64: np.dtype(np.uint64), - xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.PRED: np.dtype('bool'), + xla_data_pb2.S8: np.dtype('int8'), + xla_data_pb2.S16: np.dtype('int16'), + xla_data_pb2.S32: np.dtype('int32'), + xla_data_pb2.S64: np.dtype('int64'), + xla_data_pb2.U8: np.dtype('uint8'), + xla_data_pb2.U16: np.dtype('uint16'), + xla_data_pb2.U32: np.dtype('uint32'), + xla_data_pb2.U64: np.dtype('uint64'), + xla_data_pb2.F16: np.dtype('float16'), + xla_data_pb2.F32: np.dtype('float32'), + xla_data_pb2.F64: np.dtype('float64'), + xla_data_pb2.C64: np.dtype('complex64'), xla_data_pb2.TUPLE: np.dtype(np.object), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(v): k - for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} +DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et + for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] class LocalBuffer(object): @@ -310,6 +320,9 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None + self.dump_optimized_hlo_proto_to = None + self.dump_per_pass_hlo_proto_to = None + self.hlo_profile = False def transfer_to_infeed(value, replica_number=None): @@ -656,7 +669,7 @@ class ComputationBuilder(object): representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added pad op. + A ComputationDataHandle representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) @@ -666,7 +679,20 @@ class ComputationBuilder(object): padding_config)) def Reshape(self, operand, dimensions, new_sizes): - """Reshape op.""" + """Enqueues a reshape op onto the computation. + + Args: + operand: ComputationDataHandle representing the array to be reshaped. + dimensions: sequence of integers encoding the order in which dimensions + are collapsed or None, in which case dimensions are flattened in order. + new_sizes: sequence of integers encoding the new dimension sizes (shape). + + Returns: + A ComputationDataHandle representing the added Reshape op. + """ + if dimensions is None: + ndim = len(self.GetShape(operand).dimensions()) + dimensions = tuple(range(ndim)) return _wrap_data_handle( self._client.Reshape( _unwrap_data_handle(operand), dimensions, new_sizes)) @@ -772,11 +798,27 @@ class ComputationBuilder(object): strides = [1] * len(start_indices) return _wrap_data_handle( self._client.Slice( - _unwrap_data_handle(operand), - start_indices, - limit_indices, + _unwrap_data_handle(operand), start_indices, limit_indices, strides)) + def SliceInDim(self, operand, start_index, limit_index, stride, dimno): + """Enqueues a slice-in-dimension operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_index: an integer containing the start index of the slice. + limit_index: an integer containing the end index of the slice. + stride: an integer containing the stride size for the slice. + dimno: an integer indicating the dimension along which to slice. + + Returns: + A ComputationDataHandle representing the added Slice op. + """ + return _wrap_data_handle( + self._client.SliceInDim( + _unwrap_data_handle(operand), start_index, limit_index, stride, + dimno)) + def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. @@ -986,6 +1028,20 @@ class ComputationBuilder(object): _unwrap_data_handle(false_operand), false_computation.c_local_computation)) + def IsConstant(self, operand, num_parameters=0): + """Enqueues an IsConstant operation onto the computation. + + Args: + operand: a ComputationDataHandle to test. + num_parameters: optional int, number of computation parameters to treat as + constant (default 0). + + Returns: bool indicating whether `operand` is a compile-time constant, + meaning its value does not depend on parameters with index greater than or + equal to `num_parameters`. + """ + return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c9d09cd5d57e001fd48d2dba9f2b0ee18374231b..433ea568776102539d4cfe7364eedafe22caf6ac 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -762,6 +762,23 @@ class SingleOpTest(LocalComputationTest): [3, 2]) self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + def testSliceInDim(self): + c = self._NewComputation() + c.SliceInDim( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=1, + limit_index=2, + stride=1, + dimno=1) + self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) + c.SliceInDim( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=0, + limit_index=3, + stride=2, + dimno=0) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) + def testDynamicSlice(self): c = self._NewComputation() c.DynamicSlice( @@ -838,6 +855,17 @@ class SingleOpTest(LocalComputationTest): self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) + def testIsConstant(self): + c = self._NewComputation() + a = c.ConstantS32Scalar(3) + b = c.ConstantS32Scalar(1) + x = c.ParameterFromNumpy(NumpyArrayS32(0)) + const_expr = c.Sub(b, a) + non_const_expr = c.Mul(const_expr, x) + self.assertTrue(c.IsConstant(const_expr)) + self.assertFalse(c.IsConstant(non_const_expr)) + # self.assertTrue(c.IsConstant(c.Sub(c.Add(x, a), x))) # TODO(b/77245564) + class EmbeddedComputationsTest(LocalComputationTest): """Tests for XLA graphs with embedded computations (such as maps).""" @@ -1132,7 +1160,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose( c, expected=np.sum(input_array, axis=tuple(dims))) - _ReduceAndTest(0) _ReduceAndTest(0) _ReduceAndTest(0, 1) _ReduceAndTest(0, 2) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a9acdae380af5b7f9efb3d08302fc717108f5e40..ad3a28e11939d6259ebd75d544a950ba7abd741f 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -30,29 +30,23 @@ limitations under the License. namespace xla { -/* static */ std::unique_ptr> ReferenceUtil::TransposeArray2D( - const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); - for (int64 w = 0; w < operand.width(); ++w) { - for (int64 h = 0; h < operand.height(); ++h) { - (*result)(w, h) = operand(h, w); - } - } - - return result; -} - -/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( - const Array2D& lhs, const Array2D& rhs) { +namespace { + +template +std::unique_ptr> MatmulArray2DImpl( + const Array2D& lhs, const Array2D& rhs, + const std::function& impl_fn) { CHECK_EQ(lhs.width(), rhs.height()); int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = MakeUnique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). - __xla_cpu_runtime_EigenSingleThreadedMatMulF32( + impl_fn( /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, k, /*transpose_lhs=*/0, @@ -60,22 +54,24 @@ namespace xla { return result; } +} // namespace + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF16); +} + +/* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( + const Array2D& lhs, const Array2D& rhs) { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF32); +} + /* static */ std::unique_ptr> ReferenceUtil::MatmulArray2D( const Array2D& lhs, const Array2D& rhs) { - CHECK_EQ(lhs.width(), rhs.height()); - int m = lhs.height(); - int n = rhs.width(); - int k = lhs.width(); - auto result = MakeUnique>(m, n); - // Because Eigen is a header-oriented library, make sure that the Eigen code - // is the same as the code used by the CPU backend (otherwise the linker will - // randomly pick *some* definition). - __xla_cpu_runtime_EigenSingleThreadedMatMulF64( - /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, - k, - /*transpose_lhs=*/0, - /*transpose_rhs=*/0); - return result; + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF64); } /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( @@ -188,18 +184,6 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); } -/* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, - const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { - std::vector dim_lengths{static_cast(operand.size())}; - return ReduceWindow1DGeneric( - operand, init, reduce_func, window, stride, - xla::MakePadding(dim_lengths, window, stride, padding)); -} - /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& operand, float init, @@ -239,23 +223,28 @@ ReferenceUtil::ReduceWindow1DAdd( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; - return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, - padding); + std::vector dim_lengths{static_cast(operand.size())}; + return ReduceWindow1DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); } -/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, + const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { std::vector dim_lengths{operand.height(), operand.width()}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0], window_counts[1]); @@ -271,7 +260,7 @@ ReferenceUtil::ReduceWindow1DAdd( if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && i0_base + i0_win < operand.n1() && i1_base + i1_win < operand.n2()) { - val += operand(i0_base + i0_win, i1_base + i1_win); + val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win)); } } } @@ -281,6 +270,17 @@ ReferenceUtil::ReduceWindow1DAdd( return result; } +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( + const Array2D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; + std::vector dim_lengths{operand.height(), operand.width()}; + return ReduceWindow2DGeneric( + operand, init, add_reduce, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -472,7 +472,7 @@ ReferenceUtil::SelectAndScatter4DGePlus( i3_base + i3_win < operand.n4()) { float tmp = operand(i0_base + i0_win, i1_base + i1_win, i2_base + i2_win, i3_base + i3_win); - if (tmp >= val) { + if (tmp > val) { val = tmp; scatter_0 = i0_base + i0_win; scatter_1 = i1_base + i1_win; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 3ec96f2f38b8f91e1549419b60481327fa9bbd5f..28d6a8c3fe85fa4179bf2f41c82ad4eb93a045fe 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -39,10 +39,22 @@ namespace xla { class ReferenceUtil { public: // Returns the result of a transpose operation on the input matrix. - static std::unique_ptr> TransposeArray2D( - const Array2D& operand); + template + static std::unique_ptr> TransposeArray2D( + const Array2D& operand) { + auto result = MakeUnique>(operand.width(), operand.height()); + for (int64 w = 0; w < operand.width(); ++w) { + for (int64 h = 0; h < operand.height(); ++h) { + (*result)(w, h) = operand(h, w); + } + } + + return result; + } // Returns the result of a matrix multiply `lhs x rhs`. + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); static std::unique_ptr> MatmulArray2D( @@ -187,9 +199,10 @@ class ReferenceUtil { const tensorflow::gtl::ArraySlice& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); - static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); + static std::unique_ptr> ReduceWindow2DGeneric( + const Array2D& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, @@ -215,6 +228,7 @@ class ReferenceUtil { // Performs select and scatter with Greater Than or equal as the select, plus // as the scatter, and Same Padding. + // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, const tensorflow::gtl::ArraySlice& window, diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..977f8637873a4b6555798f533010a28ff36e8679 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -0,0 +1,79 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/compiler/xla:xla.bzl", + "xla_proto_library", + "xla_py_grpc_library", +) + +xla_proto_library( + name = "xla_service_proto", + srcs = ["xla_service.proto"], + use_grpc_plugin = True, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + ], +) + +cc_library( + name = "grpc_stub", + srcs = ["grpc_stub.cc"], + hdrs = ["grpc_stub.h"], + deps = [ + ":xla_service_proto", + "//tensorflow/compiler/xla:service_interface", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + ], +) + +tf_cc_binary( + name = "grpc_service_main_cpu", + srcs = ["grpc_service_main.cc"], + deps = [ + ":grpc_service", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "@grpc//:grpc++_unsecure", + ], +) + +tf_cc_test( + name = "grpc_client_test", + srcs = ["grpc_client_test.cc"], + data = [ + "//tensorflow/compiler/xla/rpc:grpc_service_main_cpu", + ], + deps = [ + ":grpc_stub", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@grpc//:grpc++_unsecure", + ], +) + +cc_library( + name = "grpc_service", + srcs = ["grpc_service.cc"], + hdrs = ["grpc_service.h"], + deps = [ + ":xla_service_proto", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "@grpc//:grpc++_unsecure", + ], +) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b559ee4b5a345dbb2cc481b571562a0a630b3294 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple C++ test to exercise the GRPC capabilities of XLA. +// +// Launches an RPC service in a subprocess and connects to it over a socket +// using an RPCStub. +#include +#include + +#include "grpc++/create_channel.h" +#include "grpc++/security/credentials.h" + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/rpc/grpc_stub.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/net.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class GRPCClientTestBase : public ::testing::Test { + protected: + GRPCClientTestBase() { + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + string service_main_path = tensorflow::io::JoinPath( + test_srcdir, "compiler/xla/rpc/grpc_service_main_cpu"); + int port = tensorflow::internal::PickUnusedPortOrDie(); + subprocess_.SetProgram( + service_main_path, + {service_main_path, tensorflow::strings::Printf("--port=%d", port)}); + subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT, + tensorflow::ACTION_DUPPARENT); + subprocess_.SetChannelAction(tensorflow::CHAN_STDERR, + tensorflow::ACTION_DUPPARENT); + CHECK(subprocess_.Start()); + LOG(INFO) << "Launched subprocess"; + + auto channel = + ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port), + ::grpc::InsecureChannelCredentials()); + channel->WaitForConnected(gpr_time_add( + gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); + LOG(INFO) << "Channel to server is connected on port " << port; + + xla_service_ = grpc::XlaService::NewStub(channel); + stub_.reset(new GRPCStub(xla_service_.get())); + client_.reset(new Client(stub_.get())); + } + + ~GRPCClientTestBase() override { + LOG(INFO) << "Killing subprocess"; + subprocess_.Kill(SIGKILL); + } + + tensorflow::SubProcess subprocess_; + std::unique_ptr xla_service_; + std::unique_ptr stub_; + std::unique_ptr client_; +}; + +TEST_F(GRPCClientTestBase, ItsAlive) { + ASSERT_NE(xla_service_, nullptr); + ASSERT_NE(stub_, nullptr); + ASSERT_NE(client_, nullptr); +} + +TEST_F(GRPCClientTestBase, AxpyTenValues) { + ComputationBuilder builder(client_.get(), "axpy_10"); + auto alpha = builder.ConstantR0(3.1415926535); + auto x = builder.ConstantR1( + {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); + auto y = builder.ConstantR1( + {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); + auto ax = builder.Mul(alpha, x); + auto axpy = builder.Add(ax, y); + + std::vector expected = { + 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, + 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; + std::unique_ptr expected_literal = + Literal::CreateR1(expected); + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( + computation, {}, nullptr)); + LiteralTestUtil::ExpectNear(*expected_literal, *result_literal, + ErrorSpec(0.0001)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc new file mode 100644 index 0000000000000000000000000000000000000000..414829d6e76354672c7c1998d1fb1bd185043d78 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -0,0 +1,192 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/rpc/grpc_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace xla { + +/* static */ StatusOr> GRPCService::NewService( + perftools::gputools::Platform* platform) { + std::unique_ptr grpc_service(new GRPCService()); + TF_ASSIGN_OR_RETURN(grpc_service->service_, + ::xla::Service::NewService(platform)); + return std::move(grpc_service); +} + +::grpc::Status DelegateRPC(std::function op) { + tensorflow::Status s = op(); + return tensorflow::ToGrpcStatus(s); +} + +::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, + const ComputationRequest* arg, + ComputationResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Computation(arg, result); }); +} + +::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, + const OpRequest* arg, OpResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Op(arg, result); }); +} + +::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, + const UnregisterRequest* arg, + UnregisterResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Unregister(arg, result); }); +} + +::grpc::Status GRPCService::DeconstructTuple(::grpc::ServerContext* context, + const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->DeconstructTuple(arg, result); + }); +} + +::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, + const SetReturnValueRequest* arg, + SetReturnValueResponse* results) { + return DelegateRPC([this, arg, results]() { + return service_->SetReturnValue(arg, results); + }); +} + +::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, + const ExecuteRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->Execute(arg, result); }); +} + +::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, + const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); +} + +::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->WaitForExecution(arg, result); + }); +} + +::grpc::Status GRPCService::TransferToClient(::grpc::ServerContext* context, + const TransferToClientRequest* arg, + TransferToClientResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferToClient(arg, result); + }); +} + +::grpc::Status GRPCService::TransferToServer(::grpc::ServerContext* context, + const TransferToServerRequest* arg, + TransferToServerResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferToServer(arg, result); + }); +} + +::grpc::Status GRPCService::TransferToInfeed(::grpc::ServerContext* context, + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferToInfeed(arg, result); + }); +} + +::grpc::Status GRPCService::TransferFromOutfeed( + ::grpc::ServerContext* context, const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->TransferFromOutfeed(arg, result); + }); +} + +::grpc::Status GRPCService::ResetDevice(::grpc::ServerContext* context, + const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ResetDevice(arg, result); }); +} + +::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context, + const IsConstantRequest* arg, + IsConstantResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->IsConstant(arg, result); }); +} + +::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, + const ComputeConstantRequest* arg, + ComputeConstantResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ComputeConstant(arg, result); }); +} + +::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, + const GetShapeRequest* arg, + GetShapeResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->GetShape(arg, result); }); +} + +::grpc::Status GRPCService::GetComputationShape( + ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->GetComputationShape(arg, result); + }); +} + +::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, + const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->GetLocalShape(arg, result); }); +} + +::grpc::Status GRPCService::GetComputationStats( + ::grpc::ServerContext* context, const ComputationStatsRequest* arg, + ComputationStatsResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->GetComputationStats(arg, result); + }); +} + +::grpc::Status GRPCService::SnapshotComputation( + ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->SnapshotComputation(arg, result); + }); +} + +::grpc::Status GRPCService::LoadComputationSnapshot( + ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, + LoadComputationSnapshotResponse* result) { + return DelegateRPC([this, arg, result]() { + return service_->LoadComputationSnapshot(arg, result); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h new file mode 100644 index 0000000000000000000000000000000000000000..7c9e484517e9ced45c40dda78a2bd427a24c2722 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ + +#include "grpc++/server_context.h" +#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" +#include "tensorflow/compiler/xla/service/service.h" + +namespace xla { + +// Service implementation which wraps a XLA Service with a GRPC interface. +class GRPCService : public grpc::XlaService::Service { + public: + // Factory for creating a RPCService. The parameter platform is the platform + // that the service should target. If platform is null then the default + // platform is used. + static StatusOr> NewService( + perftools::gputools::Platform* platform = nullptr); + + ::grpc::Status Computation(::grpc::ServerContext* context, + const ComputationRequest* arg, + ComputationResponse* result) override; + + ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, + OpResponse* result) override; + + ::grpc::Status Unregister(::grpc::ServerContext* context, + const UnregisterRequest* arg, + UnregisterResponse* result) override; + + ::grpc::Status DeconstructTuple(::grpc::ServerContext* context, + const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; + + ::grpc::Status SetReturnValue(::grpc::ServerContext* context, + const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; + + ::grpc::Status Execute(::grpc::ServerContext* context, + const ExecuteRequest* arg, + ExecuteResponse* result) override; + + ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, + const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; + + ::grpc::Status WaitForExecution(::grpc::ServerContext* context, + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; + + ::grpc::Status TransferToClient(::grpc::ServerContext* context, + const TransferToClientRequest* arg, + TransferToClientResponse* result) override; + + ::grpc::Status TransferToServer(::grpc::ServerContext* context, + const TransferToServerRequest* arg, + TransferToServerResponse* result) override; + + ::grpc::Status TransferToInfeed(::grpc::ServerContext* context, + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; + + ::grpc::Status TransferFromOutfeed( + ::grpc::ServerContext* context, const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; + + ::grpc::Status ResetDevice(::grpc::ServerContext* context, + const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; + + ::grpc::Status IsConstant(::grpc::ServerContext* context, + const IsConstantRequest* arg, + IsConstantResponse* result) override; + + ::grpc::Status ComputeConstant(::grpc::ServerContext* context, + const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; + + ::grpc::Status GetShape(::grpc::ServerContext* context, + const GetShapeRequest* arg, + GetShapeResponse* result) override; + + ::grpc::Status GetComputationShape( + ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; + + ::grpc::Status GetLocalShape(::grpc::ServerContext* context, + const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; + + ::grpc::Status GetComputationStats(::grpc::ServerContext* context, + const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; + + ::grpc::Status SnapshotComputation( + ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) override; + + ::grpc::Status LoadComputationSnapshot( + ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, + LoadComputationSnapshotResponse* result) override; + + private: + std::unique_ptr<::xla::Service> service_; + + GRPCService() {} + GRPCService(const GRPCService&) = delete; + void operator=(const GRPCService&) = delete; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..e29908ccec80db76e3b5b856e57382c56430c379 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic server binary that exposes a xla::Service through a GRPC interface +// on a configurable port. +#include "grpc++/security/server_credentials.h" +#include "grpc++/server.h" +#include "grpc++/server_builder.h" +#include "tensorflow/compiler/xla/rpc/grpc_service.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace xla { +namespace { + +int RealMain(int argc, char** argv) { + int32 port = 1685; + std::vector flag_list = { + tensorflow::Flag("port", &port, "port to listen on"), + }; + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parsed_values_ok) { + LOG(ERROR) << usage; + return 2; + } + tensorflow::port::InitMain(argv[0], &argc, &argv); + + std::unique_ptr service = + xla::GRPCService::NewService().ConsumeValueOrDie(); + + ::grpc::ServerBuilder builder; + string server_address(tensorflow::strings::Printf("localhost:%d", port)); + + builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); + builder.RegisterService(service.get()); + std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); + + LOG(INFO) << "Server listening on " << server_address; + server->Wait(); + + return 0; +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { return xla::RealMain(argc, argv); } diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc new file mode 100644 index 0000000000000000000000000000000000000000..e1f2b0abe39b10dd82b700941748bc4f4e8cb2f8 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -0,0 +1,244 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/rpc/grpc_stub.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace xla { + +GRPCStub::~GRPCStub() = default; + +tensorflow::Status MakeRPC( + const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { + ::grpc::ClientContext context; + ::grpc::Status s = rpc_method(&context); + return tensorflow::FromGrpcStatus(s); +} + +tensorflow::Status GRPCStub::TransferToClient( + const TransferToClientRequest* request, + TransferToClientResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferToClient(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::TransferToServer( + const TransferToServerRequest* request, + TransferToServerResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferToServer(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::TransferToInfeed( + const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferToInfeed(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::TransferFromOutfeed( + const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->TransferFromOutfeed(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ResetDevice(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::LoadComputationSnapshot( + const LoadComputationSnapshotRequest* request, + LoadComputationSnapshotResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->LoadComputationSnapshot(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Execute(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ExecuteGraph(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ExecuteParallel( + const ExecuteParallelRequest* request, ExecuteParallelResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ExecuteParallel(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ExecuteGraphParallel( + const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ExecuteGraphParallel(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, + ExecuteAsyncResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ExecuteAsync(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::WaitForExecution( + const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->WaitForExecution(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::DeconstructTuple( + const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->DeconstructTuple(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::GetComputationStats( + const ComputationStatsRequest* request, + ComputationStatsResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetComputationStats(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::GetComputationGraphStats( + const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetComputationGraphStats(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::GetComputationShape( + const GetComputationShapeRequest* request, + GetComputationShapeResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetComputationShape(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetShape(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::GetDeviceHandles( + const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetDeviceHandles(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::CreateChannelHandle( + const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->CreateChannelHandle(context, *request, response); + }); +} + +// Methods used by ComputationBuilder. +tensorflow::Status GRPCStub::Computation(const ComputationRequest* request, + ComputationResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Computation(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::Op(const OpRequest* request, + OpResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->CreateOp(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, + GetLocalShapeResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->GetLocalShape(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::SetReturnValue( + const SetReturnValueRequest* request, SetReturnValueResponse* responses) { + return MakeRPC([this, request, responses](::grpc::ClientContext* context) { + return grpc_stub_->SetReturnValue(context, *request, responses); + }); +} + +tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request, + IsConstantResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->IsConstant(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ComputeConstant( + const ComputeConstantRequest* request, ComputeConstantResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ComputeConstant(context, *request, response); + }); +} + +tensorflow::Status GRPCStub::ComputeConstantGraph( + const ComputeConstantGraphRequest* request, + ComputeConstantResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->ComputeConstantGraph(context, *request, response); + }); +} + +// Methods used by Computation. +tensorflow::Status GRPCStub::SnapshotComputation( + const SnapshotComputationRequest* request, + SnapshotComputationResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->SnapshotComputation(context, *request, response); + }); +} + +// Methods used by GlobalData. +tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { + return MakeRPC([this, request, response](::grpc::ClientContext* context) { + return grpc_stub_->Unregister(context, *request, response); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h new file mode 100644 index 0000000000000000000000000000000000000000..fd9810d4f1a5e084b73e83007ea7f9f8b0462c72 --- /dev/null +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -0,0 +1,141 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_STUB_H_ +#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_STUB_H_ + +#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" +#include "tensorflow/compiler/xla/service_interface.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { + +class GRPCStub : public ServiceInterface { + public: + explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} + ~GRPCStub() override; + + tensorflow::Status TransferToClient( + const TransferToClientRequest* arg, + TransferToClientResponse* result) override; + + tensorflow::Status TransferToServer( + const TransferToServerRequest* arg, + TransferToServerResponse* result) override; + + tensorflow::Status TransferToInfeed( + const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; + + tensorflow::Status TransferFromOutfeed( + const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; + + tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; + + tensorflow::Status LoadComputationSnapshot( + const LoadComputationSnapshotRequest* request, + LoadComputationSnapshotResponse* result) override; + + tensorflow::Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) override; + + tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; + + tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; + + tensorflow::Status ExecuteGraphParallel( + const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; + + tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; + + tensorflow::Status WaitForExecution( + const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; + + tensorflow::Status DeconstructTuple( + const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; + + tensorflow::Status GetComputationStats( + const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; + + tensorflow::Status GetComputationGraphStats( + const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; + + tensorflow::Status GetComputationShape( + const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; + + tensorflow::Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; + + tensorflow::Status GetDeviceHandles( + const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; + + tensorflow::Status CreateChannelHandle( + const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; + + // Methods used by ComputationBuilder. + tensorflow::Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; + + tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; + tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; + + tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; + + tensorflow::Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; + + tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; + + tensorflow::Status ComputeConstantGraph( + const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; + + // Methods used by Computation. + tensorflow::Status SnapshotComputation( + const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) override; + + // Methods used by GlobalData. + tensorflow::Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; + + grpc::XlaService::Stub* service() { return grpc_stub_; } + + private: + grpc::XlaService::Stub* grpc_stub_; + + TF_DISALLOW_COPY_AND_ASSIGN(GRPCStub); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_RPC_GRPC_STUB_H_ diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto new file mode 100644 index 0000000000000000000000000000000000000000..c47164ee1b7657ae378a053f553442bee751753e --- /dev/null +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -0,0 +1,225 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA service API. +// +// Users 1) build up computations and 2) create allocations via this API. +// Computations are composed of data flowing between arbitrarily-sized +// vector-oriented operations. +// +// Users build up computations using a ComputationHandle, and talk about +// allocations using GlobalDataHandles. +// +// There are currently no checkpointing capabilities or distribution/replication +// guarantees. The service runs on a single machine (e.g. one task) and that is +// its failure domain. +// +// Canonical example of "alpha * X + Y": +// * Make a computation. +// * Add alpha and X and Y as parameters. +// * Request the multiplication of alpha and X. +// * Request the addition of that result and Y. +// +// Then, pass the computation and appropriately shaped inputs to the XLA +// service's Execute method, which provides a result as a GlobalDataHandle. +// +// All data in XLA computations are conceptually immutable. +// +// Note: this API is subject to change / refinement over time -- use the +// provided client libraries to insulate code from changes to this service API. + +syntax = "proto3"; + +import "tensorflow/compiler/xla/xla.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; + +package xla; + +service XlaService { + ///////////////////////// + // Global data requests + + // Unregisters a global allocation. + // + // If the handle given is not currently allocated, a NOT_FOUND status is + // returned. + rpc Unregister(UnregisterRequest) returns (UnregisterResponse) { + } + + // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each + // element in the tuple. + rpc DeconstructTuple(DeconstructTupleRequest) + returns (DeconstructTupleResponse) { + } + + // Unpack requests that a global data handle, with a tuple shape, has global + // data handles created for each of its constituent members. This is the + // equivalent of the "destructuring assignment" present in various programming + // languages. + rpc Unpack(UnpackRequest) returns (UnpackResponse) { + } + + // Requests the shape of the referenced global data. + rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { + } + + // Requests the program shape of the referenced computation. + rpc GetComputationShape(GetComputationShapeRequest) + returns (GetComputationShapeResponse) { + } + + // Requests the statistics of the given computation. + rpc GetComputationStats(ComputationStatsRequest) + returns (ComputationStatsResponse) { + } + + // Requests the statistics of the given computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + rpc GetComputationGraphStats(ComputationGraphStatsRequest) + returns (ComputationStatsResponse) { + } + + // Loads a variable number of values with a given element type from ColumnIO. + rpc LoadData(LoadDataRequest) returns (LoadDataResponse) { + } + + // Transfers the given global data to the client in the form of a Literal. + rpc TransferToClient(TransferToClientRequest) + returns (TransferToClientResponse) { + } + + // Transfers the given literal to the server to be stored in a global + // allocation, which is returned. + rpc TransferToServer(TransferToServerRequest) + returns (TransferToServerResponse) { + } + + // Transfers the given literal to the Infeed buffer of the device. + rpc TransferToInfeed(TransferToInfeedRequest) + returns (TransferToInfeedResponse) { + } + + // Transferred literal from the Outfeed buffer of the device. + rpc TransferFromOutfeed(TransferFromOutfeedRequest) + returns (TransferFromOutfeedResponse) { + } + + // Resets the device, clearing all existing state on the device. + rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { + } + + // Tests if an expression is a compile-time constant. + rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { + } + + // Computes the value of a constant expression. + rpc ComputeConstant(ComputeConstantRequest) + returns (ComputeConstantResponse) { + } + + // Computes the value of a constant expression. The request contains the + // computation graph for the constant expression. + rpc ComputeConstantGraph(ComputeConstantGraphRequest) + returns (ComputeConstantResponse) { + } + + // Retrieves the inferred shape for a value within a computation. + rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) { + } + + // Requests one or more device handles from the target. The returned device + // handles can be used to specify the device on which to execute computations + // or transfer data. + rpc GetDeviceHandles(GetDeviceHandlesRequest) + returns (GetDeviceHandlesResponse) { + } + + // Creates a channel handle that can be used to transfer data between + // two computations via a pair of Send and Recv instructions. + rpc CreateChannelHandle(CreateChannelHandleRequest) + returns (CreateChannelHandleResponse) { + } + + // Requests that the referenced computation be specialized for the provided + // arguments for subsequent execution. This permits things such as value + // specialization. + rpc Specialize(SpecializeRequest) returns (SpecializeResponse) { + } + + // Modifies the provided computation so that subsequent executions + // will compute the provided ComputationDataHandle, rather than the + // last expression enqueued on that Computation. + rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { + } + + // Computation creates a new computation with the given name. + // A unique ComputationHandle is returned. + rpc Computation(ComputationRequest) returns (ComputationResponse) { + } + + // Adds a new op to a computation. + rpc CreateOp(OpRequest) returns (OpResponse) { + } + + // Invokes the provided computation with the provided global data passed as + // immutable arguments. Returns global data output and execution timing. + rpc Execute(ExecuteRequest) returns (ExecuteResponse) { + } + + // Invokes the provided computation with the provided global data passed as + // immutable arguments. The request contains the whole computation graph. + // Returns global data output and execution timing. + rpc ExecuteGraph(ExecuteGraphRequest) returns (ExecuteResponse) { + } + + // Invokes the provided list of computations in parallel with the provided + // global data for each computation. Returns a list of global data output and + // execution timing. + rpc ExecuteParallel(ExecuteParallelRequest) + returns (ExecuteParallelResponse) { + } + + // Invokes the provided list of computations in parallel with the provided + // global data for each computation. Returns a list of global data output and + // execution timing. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) + returns (ExecuteParallelResponse) { + } + + // Invokes the provided computation with the provided global data passed as + // immutable arguments. Returns a handle to the execution. + rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { + } + + // Waits until the given execution (aysnchronously launched) is complete, and + // returns the global data output. + rpc WaitForExecution(WaitForExecutionRequest) + returns (WaitForExecutionResponse) { + } + + // Serializes a computation to proto form, so it can be loaded via + // LoadComputationSnapshot. + rpc SnapshotComputation(SnapshotComputationRequest) + returns (SnapshotComputationResponse) { + } + + // Loads a computation from a captured snapshot. + rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) + returns (LoadComputationSnapshotResponse) { + } +} diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 83c67ed9368bc617a90c528f200b566ee8754edd..9831a09c1fd491a5553fd46e9d7e33ca3e5f4891 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -106,6 +106,7 @@ tf_cc_test( ":bfloat16_normalization", ":bfloat16_support", ":hlo", + ":hlo_verifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", @@ -118,6 +119,42 @@ tf_cc_test( ], ) +cc_library( + name = "bfloat16_propagation", + srcs = ["bfloat16_propagation.cc"], + hdrs = ["bfloat16_propagation.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_dce", + ":hlo_pass", + ":tuple_simplifier", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_propagation_test", + srcs = ["bfloat16_propagation_test.cc"], + deps = [ + ":bfloat16_propagation", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -145,7 +182,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", ], ) @@ -247,6 +285,46 @@ cc_library( ], ) +tf_cc_test( + name = "dfs_hlo_visitor_with_default_test", + srcs = ["dfs_hlo_visitor_with_default_test.cc"], + deps = [ + ":hlo", + ":hlo_runner", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "pattern_matcher", + hdrs = ["pattern_matcher.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "pattern_matcher_test", + srcs = ["pattern_matcher_test.cc"], + deps = [ + ":hlo", + ":pattern_matcher", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_reachability", srcs = ["hlo_reachability.cc"], @@ -585,6 +663,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -674,7 +753,6 @@ cc_library( ":computation_layout", ":device_memory_allocator", ":hlo", - ":hlo_cost_analysis", ":hlo_execution_profile", ":hlo_graph_dumper", ":pool", @@ -718,6 +796,7 @@ cc_library( hdrs = ["llvm_compiler.h"], deps = [ ":compiler", + "//tensorflow/core:lib_internal", "@llvm//:core", ], ) @@ -950,6 +1029,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1026,6 +1106,38 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_module_group_metadata", + srcs = ["hlo_module_group_metadata.cc"], + hdrs = ["hlo_module_group_metadata.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_module_group_util", + srcs = ["hlo_module_group_util.cc"], + hdrs = ["hlo_module_group_util.h"], + deps = [ + ":hlo", + ":hlo_module_group_metadata", + ":hlo_reachability", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], @@ -1057,6 +1169,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1094,6 +1207,19 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_creation_utils", + srcs = ["hlo_creation_utils.cc"], + hdrs = ["hlo_creation_utils.h"], + deps = [ + ":hlo", + ":shape_inference", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + ], +) + cc_library( name = "batchnorm_expander", srcs = ["batchnorm_expander.cc"], @@ -1102,7 +1228,6 @@ cc_library( ":hlo", ":hlo_pass", ":hlo_query", - ":shape_inference", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1114,6 +1239,20 @@ cc_library( ], ) +cc_library( + name = "gather_expander", + srcs = ["gather_expander.cc"], + hdrs = ["gather_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", @@ -1141,9 +1280,10 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_creation_utils", ":hlo_pass", ":hlo_query", - ":shape_inference", + ":pattern_matcher", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1177,6 +1317,53 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_expander_test", + srcs = ["gather_expander_test.cc"], + deps = [ + ":gather_expander", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:test_macros_header", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "conditional_simplifier", + srcs = ["conditional_simplifier.cc"], + hdrs = ["conditional_simplifier.h"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "conditional_simplifier_test", + srcs = ["conditional_simplifier_test.cc"], + deps = [ + ":conditional_simplifier", + ":hlo", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop_simplifier", srcs = ["while_loop_simplifier.cc"], @@ -1199,6 +1386,7 @@ tf_cc_test( ":while_loop_simplifier", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:lib", "//tensorflow/core:test", ], ) @@ -1432,6 +1620,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -2313,6 +2502,24 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + ], +) + +tf_cc_test( + name = "hlo_proto_util_test", + srcs = ["hlo_proto_util_test.cc"], + deps = [ + ":hlo", + ":hlo_proto", + ":hlo_proto_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -2352,6 +2559,7 @@ cc_library( srcs = ["hlo_runner.cc"], hdrs = ["hlo_runner.h"], deps = [ + ":computation_placer", ":executable", ":hlo", ":transfer_manager", @@ -2368,6 +2576,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -2411,7 +2620,9 @@ cc_library( deps = [ ":call_inliner", ":hlo", + ":hlo_creation_utils", ":tuple_util", + "//tensorflow/core:lib", ], ) @@ -2454,6 +2665,21 @@ tf_cc_test( ], ) +cc_library( + name = "despecializer", + srcs = ["despecializer.cc"], + hdrs = ["despecializer.h"], + deps = [ + ":bfloat16_normalization", + ":defuser", + ":hlo", + ":hlo_pass", + ":hlo_pass_pipeline", + ":implicit_broadcast_remover", + "//tensorflow/compiler/xla:statusor", + ], +) + cc_library( name = "source_map_util", srcs = ["source_map_util.cc"], @@ -2465,17 +2691,3 @@ cc_library( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fb857559f972a220a19b108baa4c441e09b90e1f..8d26938c6e59beaea400ab605d84be5a6ddebf6d 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -26,10 +26,11 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -44,8 +45,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { + namespace { +namespace m = match; + // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -105,6 +109,7 @@ HloComputation* CreateScalarBinaryComputation(HloModule* module, module->AddEmbeddedComputation(b.Build(scalar_op)); return scalar_computation; } + } // namespace // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -122,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBitcastConvert(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleConcatenate(HloInstruction* concatenate) override; @@ -300,7 +307,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplication on platforms where it causes a slowdown. + // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; }; @@ -348,8 +355,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( } Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { - auto lhs = add->mutable_operand(0); - auto rhs = add->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); + // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { @@ -364,7 +372,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Canonicalization: Put constants on the right. This makes the reassociation // rules below simpler. VLOG(10) << "trying transform [Const + A => A + Const]"; - if (lhs->IsConstant() && !rhs->IsConstant()) { + if (Match(add, m::Add(m::Constant(), m::NonConstant()))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); @@ -377,20 +385,13 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // (A + C1) + (B + C2) => A + B + (C1 + C2). // VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; - if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && - !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { - auto* c1 = lhs->mutable_operand(1); - auto* c2 = rhs; - TF_ASSIGN_OR_RETURN( - Shape sum_of_constants_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2)); - - auto* sum_of_constants = - computation_->AddInstruction(HloInstruction::CreateBinary( - sum_of_constants_shape, HloOpcode::kAdd, c1, c2)); + HloInstruction *a, *c1, *c2; + if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2)))) { + TF_ASSIGN_OR_RETURN(auto* sum_of_constants, + MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); return ReplaceWithNewInstruction( - add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, - lhs->mutable_operand(0), + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, sum_of_constants)); } @@ -399,11 +400,11 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. - if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { + HloInstruction* op; + if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { return ReplaceWithNewInstruction( - bitcast, HloInstruction::CreateUnary( - bitcast->shape(), HloOpcode::kBitcast, - bitcast->mutable_operand(0)->mutable_operand(0))); + bitcast, + HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op)); } // All bitcasts can be eliminated (assuming layout constraints are // satisified). @@ -411,13 +412,19 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleBitcastConvert( + HloInstruction* bitcast) { + // Eliminate bitcast converts between same shape. + ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (copy->operand(0)->opcode() == HloOpcode::kCopy) { + HloInstruction* op; + if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary( - copy->shape(), HloOpcode::kCopy, - copy->mutable_operand(0)->mutable_operand(0))); + copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); @@ -457,12 +464,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } else if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. - bool is_effective_low_pad = - operands[0]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[0]->operand(0)->shape()); - bool is_effective_high_pad = - operands[1]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[1]->operand(0)->shape()); + bool is_effective_low_pad = Match( + operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); + bool is_effective_high_pad = Match( + operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); if (!is_effective_low_pad && !is_effective_high_pad) { return Status::OK(); } @@ -516,12 +521,24 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { return ReplaceInstruction( constant, BuildTupleConstant(computation_, constant->literal())); } + + // If a literal is all the same element replace it with a scalar broadcast. + if (ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsAllFirst()) { + std::unique_ptr unique_scalar = + MakeUnique(constant->literal().GetFirstScalarLiteral()); + HloInstruction* scalar = computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(unique_scalar))); + return ReplaceWithNewInstruction( + constant, + HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); + } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { - auto lhs = sub->mutable_operand(0); - auto rhs = sub->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { @@ -530,7 +547,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { // Canonicalize subtraction of a constant to addition. VLOG(10) << "trying transform [A - Const => A + (-Const)]"; - if (rhs->IsConstant() && !lhs->IsConstant()) { + if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { HloInstruction* negative_const = computation_->AddInstruction( HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); return ReplaceWithNewInstruction( @@ -542,56 +559,53 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { } Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - auto lhs = divide->mutable_operand(0); - auto rhs = divide->mutable_operand(1); + Shape* shape; + HloInstruction *a, *b, *c, *d; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) { return Status::OK(); } // exp(A)/exp(B) => exp(A-B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) + .WithShape(m::Shape(&shape)))) { VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); - HloInstruction* subtract = - computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + HloInstruction* subtract = computation_->AddInstruction( + HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, - subtract)); + divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract)); } // A/exp(B) => A*exp(-B) - if (rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) { VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b)); HloInstruction* new_exp = computation_->AddInstruction( HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, new_exp)); } // A/pow(B,C) => A*pow(B,-C) - if (rhs->opcode() == HloOpcode::kPower) { + if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) { VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); // The output shape of the created negate operator should be the same as the // input. - const Shape& negate_shape = rhs->operand(1)->shape(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); + const Shape& negate_shape = c->shape(); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c)); // And the power operator should retain the output shape of the old one. - const Shape& new_power_shape = rhs->shape(); - HloInstruction* new_power = computation_->AddInstruction( - HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, - rhs->mutable_operand(0), negate)); + const Shape& new_power_shape = b->shape(); + HloInstruction* new_power = + computation_->AddInstruction(HloInstruction::CreateBinary( + new_power_shape, HloOpcode::kPower, b, negate)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_power)); + divide->shape(), HloOpcode::kMultiply, a, new_power)); } // Simplifying integral division would produce unexpected results. @@ -603,65 +617,46 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // // (Backends can do this transformation, but generally only if the constant is // a scalar.) - if (lhs->opcode() != HloOpcode::kConstant && - rhs->opcode() == HloOpcode::kConstant) { + if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { HloInstruction* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(lhs->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = - computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kDivide, one, rhs)); + Literal::One(a->shape().element_type()).CloneToUnique())); + HloInstruction* inverse = computation_->AddInstruction( + HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, inverse)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, inverse)); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) - if (lhs->opcode() == HloOpcode::kDivide && - rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - const Shape a_times_d_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, - lhs->operand(0), rhs->operand(1))); - auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( - a_times_d_shape, HloOpcode::kMultiply, lhs->mutable_operand(0), - rhs->mutable_operand(1))); - TF_ASSIGN_OR_RETURN( - const Shape b_times_c_shape, - ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, - lhs->operand(1), rhs->operand(0))); - auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), - rhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kDivide, a_times_d, b_times_c)); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), + m::Divide(m::Op(&c), m::Op(&d))))) { + TF_ASSIGN_OR_RETURN(auto a_times_d, + MakeBinaryHlo(HloOpcode::kMultiply, a, d)); + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); + TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, + a_times_d, b_times_c)); + + return ReplaceInstruction(divide, new_divide); } // (A / B) / C => A / (B * C) - if (lhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(const Shape b_times_c_shape, - ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs->operand(1), rhs)); - auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - lhs->mutable_operand(0), b_times_c)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a, b_times_c)); } // A / (B / C) => (A*C) / B - if (rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(const Shape a_times_c_shape, - ShapeInference::InferBinaryOpShape( - HloOpcode::kMultiply, lhs, rhs->operand(1))); - auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - a_times_c_shape, HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { + TF_ASSIGN_OR_RETURN(auto a_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, a, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - a_times_c, rhs->mutable_operand(0))); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a_times_c, b)); } return Status::OK(); @@ -669,8 +664,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); int64 lhs_collapsing_dim = dot->dot_dimension_numbers().lhs_contracting_dimensions(0); if (lhs->IsRank2Transpose()) { @@ -787,8 +782,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, @@ -918,8 +913,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( } Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or // below. @@ -971,8 +966,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { - auto lhs = multiply->mutable_operand(0); - auto rhs = multiply->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { @@ -985,10 +980,9 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } // exp(A) * exp(B) => exp(A+B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( - multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + multiply->shape(), HloOpcode::kAdd, lhs, rhs)); return ReplaceWithNewInstruction( multiply, HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); @@ -999,20 +993,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); - auto operand = log->mutable_operand(0); - if (operand->opcode() == HloOpcode::kExp && - ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { + HloInstruction *a, *b; + if (Match(log, m::Log(m::Exp(m::Op(&a)))) && + ReplaceInstructionIfSameShape(log, a)) { return Status::OK(); } // ln(pow(A,B)) => B*ln(A) - if (operand->opcode() == HloOpcode::kPower) { - auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( - log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); + if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { + auto new_log = computation_->AddInstruction( + HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); return ReplaceWithNewInstruction( - log, - HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, - new_log, operand->mutable_operand(1))); + log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, b)); } return Status::OK(); @@ -1115,11 +1108,12 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, } // namespace Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { - auto operand = broadcast->mutable_operand(0); + HloInstruction* operand; + CHECK(Match(broadcast, m::Broadcast(m::Op(&operand)))); + auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. - if (std::is_sorted(broadcast->dimensions().begin(), - broadcast->dimensions().end()) && + if (std::is_sorted(dims.begin(), dims.end()) && ShapeUtil::ElementsIn(broadcast->shape()) == ShapeUtil::ElementsIn(operand->shape())) { VLOG(10) << "transform broadcast(X) -> reshape(X) where " @@ -1137,8 +1131,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { VLOG(10) << "transform broadcast(X) -> transpose(X) where " "n(broadcast(X)) == n(X)"; return ReplaceWithNewInstruction( - broadcast, HloInstruction::CreateTranspose(broadcast->shape(), operand, - broadcast->dimensions())); + broadcast, + HloInstruction::CreateTranspose(broadcast->shape(), operand, dims)); } // A broadcast of a reshape which merely inserts 1-sized dimensions can @@ -1152,7 +1146,6 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { if (merely_inserts_or_deletes_1_sized_dimensions && deleted_indices.empty()) { std::reverse(inserted_indices.begin(), inserted_indices.end()); - auto dims = broadcast->dimensions(); for (auto inserted_index : inserted_indices) { dims.erase(dims.begin() + inserted_index); } @@ -1196,6 +1189,19 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return user->ReplaceAllUsesWith(new_broadcast); } } + return Status::OK(); + } + + // Merge two consecutive broadcasts into a single one. + if (operand->opcode() == HloOpcode::kBroadcast) { + std::vector new_dimensions; + for (auto dim : operand->dimensions()) { + new_dimensions.push_back(dims[dim]); + } + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateBroadcast( + broadcast->shape(), operand->mutable_operand(0), new_dimensions)); } return Status::OK(); } @@ -1214,30 +1220,28 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { // Complex(Real(c), Imag(c)) -> c Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { - auto real = complex->mutable_operand(0); - auto imag = complex->mutable_operand(1); - if (real->opcode() == HloOpcode::kReal && - imag->opcode() == HloOpcode::kImag && - real->operand(0) == imag->operand(0)) { - return ReplaceInstruction(complex, real->mutable_operand(0)); + HloInstruction *c0, *c1; + if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) && + c0 == c1) { + return ReplaceInstruction(complex, c0); } return Status::OK(); } // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { - auto operand = real->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(real, operand->mutable_operand(0)); + HloInstruction* op; + if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) { + return ReplaceInstruction(real, op); } return Status::OK(); } // Imag(Complex(r, i)) -> i Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { - auto operand = imag->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(imag, operand->mutable_operand(1)); + HloInstruction* op; + if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) { + return ReplaceInstruction(imag, op); } return Status::OK(); } @@ -1290,17 +1294,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { padding_dimension->set_edge_padding_high(0); } } - TF_ASSIGN_OR_RETURN(Shape nonzero_pad_shape, - ShapeInference::InferPadShape(pad->operand(0)->shape(), - pad->operand(1)->shape(), - nonzero_padding)); + + TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad, + MakePadHlo(pad->mutable_operand(0), + pad->mutable_operand(1), nonzero_padding)); // Copy the layout from the original pad instructions. The new pad and the // slice instruction should all have the same layout. - TF_RETURN_IF_ERROR( - LayoutUtil::CopyLayoutBetweenShapes(pad->shape(), &nonzero_pad_shape)); - HloInstruction* nonzero_pad = computation_->AddInstruction( - HloInstruction::CreatePad(nonzero_pad_shape, pad->mutable_operand(0), - pad->mutable_operand(1), nonzero_padding)); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + pad->shape(), nonzero_pad->mutable_shape())); // Second, construct the slice instruction to perform the negative padding. std::vector start_indices; @@ -1313,7 +1314,7 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (padding_dimension.edge_padding_low() < 0) { start = -1 * padding_dimension.edge_padding_low(); } - int64 end = nonzero_pad_shape.dimensions(i); + int64 end = nonzero_pad->shape().dimensions(i); if (padding_dimension.edge_padding_high() < 0) { end += padding_dimension.edge_padding_high(); } @@ -1322,16 +1323,14 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { strides.push_back(1); } - // Verify that the slice shape matches the pad shape. TF_ASSIGN_OR_RETURN( - Shape inferred_slice_shape, - ShapeInference::InferSliceShape(nonzero_pad_shape, start_indices, - end_indices, strides)); - TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape())); + HloInstruction * slice, + MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); + + // Verify that the slice shape matches the pad shape. + TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape())); - std::unique_ptr slice = HloInstruction::CreateSlice( - pad->shape(), nonzero_pad, start_indices, end_indices, strides); - return ReplaceWithNewInstruction(pad, std::move(slice)); + return ReplaceInstruction(pad, slice); } return Status::OK(); @@ -1339,8 +1338,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - auto lhs = power->mutable_operand(0); - auto rhs = power->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( Literal::One(power->shape().element_type()).CloneToUnique()); @@ -1360,9 +1359,10 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { } // pow(exp(A),B) => exp(A*B) - if (lhs->opcode() == HloOpcode::kExp) { + HloInstruction *a, *b; + if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) { auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( - power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); + power->shape(), HloOpcode::kMultiply, a, b)); return ReplaceWithNewInstruction( power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, a_times_b)); @@ -1412,6 +1412,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } +// TODO(b/74536353): do this simplification for BroadcastDimOne as well. StatusOr AlgebraicSimplifierVisitor:: TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast) { @@ -1604,6 +1605,14 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { return ReplaceInstruction(dynamic_update_slice, update); } + + // If any dimension of update is 0, elide the DynamicUpdateSlice. This + // optimization becomes invalid should we later prefer to warn about out of + // bound indices. + if (ShapeUtil::HasZeroElements(update->shape())) { + return ReplaceInstruction(dynamic_update_slice, + dynamic_update_slice->mutable_operand(0)); + } return Status::OK(); } @@ -1686,7 +1695,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { HloInstruction::CreateReshape(reduce->shape(), arg)); return ReplaceWithNewInstruction( reduce, HloInstruction::CreateMap(reduce->shape(), - {reshape, init_value}, function)); + {init_value, reshape}, function)); } return Status::OK(); } @@ -1711,18 +1720,29 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( function)); } - VLOG(10) << "Considering folding Pad: " << operand->ToString() - << "\ninto reduce-window: " << reduce_window->ToString(); - // This optimization folds a pad op into reduce_window. - if (operand->opcode() != HloOpcode::kPad) { + HloInstruction* pad; + const HloInstruction* convert = nullptr; + if (operand->opcode() == HloOpcode::kPad) { + pad = operand; + } else if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->opcode() == HloOpcode::kPad) { + convert = operand; + pad = operand->mutable_operand(0); + } else { VLOG(10) << "Not folding pad into reduce-window as there is no pad."; return Status::OK(); } + VLOG(10) << "Considering folding Pad: " << pad->ToString() + << "\ninto reduce-window: " << reduce_window->ToString() + << (convert != nullptr ? tensorflow::strings::StrCat( + "\nvia convert: ", convert->ToString()) + : ""); + // Do not fold interior padding into ReduceWindow since the backends do not // support it. - const PaddingConfig& pad_config = operand->padding_config(); + const PaddingConfig& pad_config = pad->padding_config(); if (HasInteriorPadding(pad_config)) { VLOG(10) << "Not folding pad into reduce-window due to interior padding."; return Status::OK(); @@ -1730,14 +1750,27 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If reduce_window already has padding, the pad value of the pad op and the // init value of reduce_window must match to allow folding the pad. - const HloInstruction* pad_value = operand->operand(1); + const HloInstruction* pad_value = pad->operand(1); const HloInstruction* reduce_init_value = reduce_window->operand(1); if (pad_value != reduce_init_value) { + auto literals_are_equivalent = [&] { + auto& pad_literal = pad_value->literal(); + auto& reduce_init_literal = reduce_init_value->literal(); + if (pad_literal == reduce_init_literal) { + return true; + } + auto converted_pad_literal = pad_literal.ConvertToShape( + reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + if (!converted_pad_literal.ok()) { + return false; + } + return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - pad_value->literal() != reduce_init_value->literal()) { + !literals_are_equivalent()) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1746,7 +1779,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // If the pad puts a single non-identity value in each window that we're // reducing, then this is a broadcast. - HloInstruction* pad_operand = operand->mutable_operand(0); + HloInstruction* pad_operand = pad->mutable_operand(0); auto is_effective_broadcast = [&] { if (window_util::HasStride(window)) { VLOG(10) << "Window has stride."; @@ -1790,6 +1823,18 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Found window covers a single unpadded element."; return true; }; + + HloInstruction* new_reduce_window_operand; + if (convert != nullptr) { + new_reduce_window_operand = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad_operand->shape(), + convert->shape().element_type()), + pad_operand)); + } else { + new_reduce_window_operand = pad_operand; + } + if (is_effective_broadcast()) { VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; auto fadd = [this](std::unique_ptr x) { @@ -1798,7 +1843,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcastSequence( /*output_shape=*/reduce_window->shape(), - /*operand=*/pad_operand, fadd)); + /*operand=*/new_reduce_window_operand, fadd)); } // Carry out the folding of the pad into reduce_window. @@ -1815,10 +1860,11 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( window_dim.set_padding_high(window_dim.padding_high() + pad_dim.edge_padding_high()); } + return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateReduceWindow( /*shape=*/reduce_window->shape(), - /*operand=*/pad_operand, + /*operand=*/new_reduce_window_operand, /*init_value=*/reduce_window->mutable_operand(1), /*window=*/new_window, /*reduce_computation=*/function)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 43315f5cdc7afbe79039420320f4a0d0535e11f1..c48196e861a559a5abfa360841ec70b39356fa2b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -// A pass which performs AlgebraicSimplications. +// A pass which performs algebraic simplifications. class AlgebraicSimplifier : public HloPassInterface { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to @@ -57,10 +57,10 @@ class AlgebraicSimplifier : public HloPassInterface { bool is_layout_sensitive_; ValidBitcastCallback valid_bitcast_callback_; - // Enable dot simplication on platforms where it is profitable. + // Enable dot simplification on platforms where it is profitable. bool enable_dot_strength_reduction_; - // Enable convolution simplication on platforms where it is profitable. + // Enable convolution simplification on platforms where it is profitable. bool enable_conv_simplification_; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0f08eb3a3267c4b7b04958270a5788fc48d3fa04..20c549562d5153c802c1e675a8ff1c92426b8832 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -35,6 +35,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" +using ::testing::ElementsAre; + namespace xla { namespace { @@ -162,6 +164,37 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({3.14f, 3.14f, 3.14f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); +} + +TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({3.14, 3.14, 4}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -2305,6 +2338,91 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); } +// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to +// ReduceWindow(Convert(op), x). +TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create operand to the pad. + HloInstruction* parameter = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0")); + + // Create the pad. + PaddingConfig padding = MakeNoPaddingConfig(4); + padding.mutable_dimensions(1)->set_edge_padding_low(1); + padding.mutable_dimensions(3)->set_edge_padding_high(2); + + HloInstruction* pad_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding)); + + HloInstruction* convert = + builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(pad->shape(), F32), pad)); + + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module.AddEmbeddedComputation(builder.Build()); + } + + // Create the reduce-window. + Window window; + for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { + auto* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(10); + dim->set_padding_high(100); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + const Shape reduce_window_shape = + ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); + HloInstruction* reduce_init_value = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); + HloInstruction* reduce_window = + builder.AddInstruction(HloInstruction::CreateReduceWindow( + reduce_window_shape, convert, reduce_init_value, window, + add_computation)); + + // Build the computation and run the simplifier. + auto computation = module.AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, reduce_window); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + + // Running simplification again should not result in any further changes. + ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + + // Verify the result + root = computation->root_instruction(); + EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant())); + EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) + << ShapeUtil::HumanString(root->shape()) << " vs " + << ShapeUtil::HumanString(reduce_window_shape); + EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); + EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); + EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); + EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); +} + TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { HloComputation::Builder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1}); @@ -2431,6 +2549,55 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* input_array = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({3, 4}))); + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, input_array, {1})); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, inner_bcast, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_THAT(root->dimensions(), ElementsAre(2)); +} + +// Test that two consecutive broadcasts can be merged to one. +TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + // The initial dimensions go to places 0 and 2 in the 3-dim array, + // and to places 1 and 3 in the 4-dim array, + HloInstruction* inner_bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r3f32, param0, {0, 2})); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, inner_bcast, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Parameter(0))); + EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -2769,6 +2936,29 @@ DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { {/*m=*/1, /*k=*/16, /*n=*/1}, // }; +// Test that DynamicUpdateSlice update param with any dimension equal to zero +// gets removed. +TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { + HloComputation::Builder builder(TestName()); + const Shape dslice_shape = ShapeUtil::MakeShape(F32, {10}); + HloInstruction* const operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, dslice_shape, "operand")); + const Shape update_shape = ShapeUtil::MakeShape(F32, {0}); + HloInstruction* const update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + HloInstruction* const start_indices = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0}))); + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + dslice_shape, operand, update, start_indices)); + const HloComputation* const computation = + module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), operand); +} + INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 4e80679c11dfdf7fdf8077a9f354139a4cab6803..4f819a743c48f30df8dde00ece72a0b4e1748802 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -34,40 +34,54 @@ StatusOr AllocationTracker::Register( std::unique_ptr shaped_buffer, const string& tag) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(std::move(shaped_buffer), tag); + std::vector> replicated_buffers; + replicated_buffers.emplace_back(std::move(shaped_buffer)); + return RegisterInternal(std::move(replicated_buffers), tag); +} + +StatusOr AllocationTracker::RegisterReplicatedBuffers( + std::vector> replicated_buffers, + const string& tag) { + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "RegisterReplicatedBuffers"; + return RegisterInternal(std::move(replicated_buffers), tag); } StatusOr AllocationTracker::RegisterInternal( - std::unique_ptr shaped_buffer, const string& tag) { + std::vector> replicated_buffers, + const string& tag) { VLOG(2) << "RegisterInternal(" - << "tag: \"" << tag << "\" " - << "shaped_buffer: " << *shaped_buffer; - if (shaped_buffer->platform() != backend_->platform()) { - return InvalidArgument( - "AllocationTracker for platform %s cannot register buffer from " - "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer->platform()->Name().c_str()); + << "tag: \"" << tag << "\" with " << replicated_buffers.size() + << " shaped_buffers."; + for (const auto& shaped_buffer : replicated_buffers) { + VLOG(2) << "shaped_buffer:" << *shaped_buffer; + if (shaped_buffer->platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer->platform()->Name().c_str()); + } } int64 handle = next_handle_++; - std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); - for (const ShapeIndex& index : shape_indices) { - AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal()); + for (auto& shaped_buffer : replicated_buffers) { + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal()); + } + handle_to_shaped_buffers_[handle].emplace_back(std::move(shaped_buffer)); } + GlobalDataHandle result; result.set_handle(handle); - - handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); - VLOG(2) << "handle: " << handle; - return result; } @@ -75,23 +89,35 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; - TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); - std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); - for (const ShapeIndex& index : shape_indices) { - TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + for (const auto& shaped_buffer : replicated_buffers) { + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); + } } + return Reset(data); +} - // Keep a nullptr as a tombstone for unregistered handles. This enables better - // error messages. That is, "handle has been deallocated" versus "handle does - // not exist". - handle_to_shaped_buffer_.at(data.handle()).reset(); - +Status AllocationTracker::Reset(const GlobalDataHandle& data) { + // Keep a nullptr as a tombstone for unregistered handles. This enables + // better error messages. That is, "handle has been deallocated" versus + // "handle does not exist". + auto it = handle_to_shaped_buffers_.find(data.handle()); + if (it == handle_to_shaped_buffers_.end()) { + return NotFound("no allocation record for global data handle: %lld", + data.handle()); + } + for (auto& shaped_buffer : it->second) { + shaped_buffer.reset(); + } return tensorflow::Status::OK(); } @@ -99,7 +125,11 @@ StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); - TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + // We only need to care about replica id 0 here, since the GlobalDataHandle is + // the same for all buffers across replicas. + const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); @@ -109,7 +139,7 @@ StatusOr> AllocationTracker::DeconstructTuple( TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { - return Unimplemented("deconstructing nested tuples not yet supported"); + return Unimplemented("Deconstructing nested tuples is not implemented."); } std::vector element_handles; @@ -122,37 +152,55 @@ StatusOr> AllocationTracker::DeconstructTuple( shaped_buffer->platform(), shaped_buffer->device_ordinal()); element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), /*index=*/{}); + std::vector> replicated_buffers; + replicated_buffers.emplace_back(std::move(element_buffer)); TF_ASSIGN_OR_RETURN( GlobalDataHandle element_handle, - RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + RegisterInternal(std::move(replicated_buffers), "deconstructed tuple")); element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr> AllocationTracker::Resolve( const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveForReplica( + const GlobalDataHandle& data, int replica_id) { + tensorflow::mutex_lock lock(mutex_); + TF_ASSIGN_OR_RETURN(std::vector replicated_buffers, + ResolveInternal(data)); + if (replica_id >= replicated_buffers.size()) { + return InvalidArgument( + "Requesting buffer for replica %d, but found buffers only for %lu " + "replicas.", + replica_id, replicated_buffers.size()); + } + return replicated_buffers[replica_id]; +} + +StatusOr> AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) << "resolve:" << data.handle(); - auto it = handle_to_shaped_buffer_.find(data.handle()); - if (it == handle_to_shaped_buffer_.end()) { + auto it = handle_to_shaped_buffers_.find(data.handle()); + if (it == handle_to_shaped_buffers_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - ShapedBuffer* shaped_buffer = it->second.get(); - - if (shaped_buffer == nullptr) { - return InvalidArgument("global data handle %lld was previously deallocated", - data.handle()); + std::vector replicated_buffers; + for (const auto& shaped_buffer : it->second) { + if (shaped_buffer == nullptr) { + return InvalidArgument( + "global data handle %lld was previously deallocated", data.handle()); + } + replicated_buffers.push_back(shaped_buffer.get()); } - return shaped_buffer; + return replicated_buffers; } void AllocationTracker::AddAllocationOrIncrementRefCount( diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 807af8694972083d097604a67ee46d2f73d9545a..038aee8541b297d6f91fe2b3bce7455fd9a7084e 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -43,10 +43,17 @@ class AllocationTracker { AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} // Registers a shaped buffer of device memory, and returns a corresponding - // handle that can be used for talking to XLA clients. + // handle that can be used for talking to XLA clients. The given shaped buffer + // will be treated as the buffer corresponding to the only replica. StatusOr Register( std::unique_ptr shaped_buffer, const string& tag); + // Registers a vector of shaped buffers of device memory, one per replica, and + // returns a corresponding handle that can be used for talking to XLA clients. + StatusOr RegisterReplicatedBuffers( + std::vector> replicated_buffers, + const string& tag); + // Unregister the allocation for the given data handle. Status Unregister(const GlobalDataHandle& data); @@ -54,9 +61,17 @@ class AllocationTracker { StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to a shaped buffer, or provide an error - // status to say whether it was not found (or found, but found deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a vector of shaped buffers, one per + // replica, or provide an error status to say whether any of those buffers + // were not found (or found, but found deallocated). + StatusOr> Resolve( + const GlobalDataHandle& data); + + // Resolves a handle from an XLA client and replica id to a shaped buffer, or + // provide an error status to say whether it was not found (or found, but + // found deallocated). + StatusOr ResolveForReplica(const GlobalDataHandle& data, + int replica_id); private: // Data structure encapsulating single memory allocation on the device. @@ -74,13 +89,17 @@ class AllocationTracker { // Internal helper which resolves the given GlobalDataHandle to a // ShapedBuffer. - StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + StatusOr> ResolveInternal( + const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Internal helper which registers a shaped buffer. + // Internal helper which registers a vector of shaped buffers, one per + // replica. StatusOr RegisterInternal( - std::unique_ptr shaped_buffer, const string& tag) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); + std::vector> replicated_buffers, + const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Resets the shaped buffers corresponding to the given handle. + Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Adds the given device address to the allocation tracker, or if it already // exists, then increment it's reference count. @@ -111,9 +130,10 @@ class AllocationTracker { tensorflow::gtl::FlatMap opaque_to_allocation_map_ GUARDED_BY(mutex_); - // A map from data handle to ShapedBuffer. - tensorflow::gtl::FlatMap> - handle_to_shaped_buffer_ GUARDED_BY(mutex_); + // A map from data handle to a vector of shaped buffers that represent the + // buffers for different replicas. + tensorflow::gtl::FlatMap>> + handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 27ddfd47aa3096afd3e245af1ac3cedd9b48ce4a..38086bd7e121847be6b6b69415cfe87814e7fc24 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -153,6 +152,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; @@ -334,6 +334,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; @@ -419,6 +420,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index cde990e176ddb57a8e93ecc3c60260b2dbae32a8..08d0152e3cfcfcb7ae1e85f72c2f7dc856f5e8b3 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -34,6 +34,9 @@ class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); @@ -84,6 +87,25 @@ Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( return Status::OK(); } +namespace { + +// Returns whether hlo has users and all users are conversions from F32 to BF16. +bool AllUsersAreF32ToBF16Converts(const HloInstruction* hlo) { + if (hlo->user_count() == 0 || hlo->shape().element_type() != F32) { + return false; + } + for (const auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + return false; + } + return true; +} + +} // namespace + Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( HloInstruction* hlo) { std::vector bf16_to_f32_operands; @@ -104,22 +126,9 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( } } - bool fold_output_conversion = hlo->user_count() > 0 && - hlo->shape().element_type() == F32 && - bfloat16_support_->SupportsBF16Output(*hlo) && - hlo != computation_->root_instruction(); - if (fold_output_conversion) { - for (auto user : hlo->users()) { - if (user->opcode() == HloOpcode::kConvert && - user->shape().element_type() == BF16) { - continue; - } - // We should not change the output type if any user is not a conversion - // from F32 to BF16. - fold_output_conversion = false; - break; - } - } + const bool fold_output_conversion = + AllUsersAreF32ToBF16Converts(hlo) && + bfloat16_support_->SupportsBF16Output(*hlo); if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { if (has_other_f32_operands || @@ -147,6 +156,10 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kGetTupleElement || // hlo->opcode() == HloOpcode::kInfeed || // hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kSend || // + hlo->opcode() == HloOpcode::kSendDone || // + hlo->opcode() == HloOpcode::kRecv || // + hlo->opcode() == HloOpcode::kRecvDone || // hlo->opcode() == HloOpcode::kConstant || // hlo->opcode() == HloOpcode::kParameter || // hlo->opcode() == HloOpcode::kFusion || // @@ -167,6 +180,52 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { return TryFoldBF16Conversions(hlo); } +Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape()) || + !bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return DefaultAction(crs); + } + + // First use DefaultAction() to handle the operands. It can't handle + // tuple-shaped output. + TF_RETURN_IF_ERROR(DefaultAction(crs)); + + // Then do per-tuple-element handling on the output. + std::vector> per_tuple_element_gtes( + crs->operand_count()); + for (auto user : crs->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return Status::OK(); + } + per_tuple_element_gtes[user->tuple_index()].push_back(user); + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + // Fold conversions only when all the get-tuple-elements' users are + // conversions from F32 to BF16. + auto all_gte_users_are_bf16_convert = [&per_tuple_element_gtes, i]() { + for (auto gte : per_tuple_element_gtes[i]) { + if (!AllUsersAreF32ToBF16Converts(gte)) { + return false; + } + } + return true; + }; + if (!all_gte_users_are_bf16_convert()) { + continue; + } + + ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}) + ->set_element_type(BF16); + for (auto gte : per_tuple_element_gtes[i]) { + TF_RETURN_IF_ERROR(FoldOutputConversions(gte)); + } + } + + return Status::OK(); +} + StatusOr BFloat16ConversionFolding::Run(HloModule* module) { XLA_VLOG_LINES( 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index cb37759439debf41a305ec7dccaa548e1bf234cd..28e71c2054f59ba4d5d096bf7d898161877bb42f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -37,7 +37,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -47,7 +48,8 @@ class TestBFloat16Support : public BFloat16Support { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kSubtract || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -55,7 +57,8 @@ class TestBFloat16Support : public BFloat16Support { bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || - hlo.opcode() == HloOpcode::kGetTupleElement) { + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kCrossReplicaSum) { return true; } return false; @@ -206,4 +209,46 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { EXPECT_EQ(tuple->operand(1), convert0); } +TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* convert_a = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, a)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + HloInstruction* gte_a = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); + HloInstruction* gte_b = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, crs, 1)); + HloInstruction* convert_gte_b = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte_b)); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({gte_a, convert_gte_b})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_EQ(tuple->operand(0), gte_a); + EXPECT_EQ(tuple->operand(1), gte_b); + EXPECT_EQ(gte_a->shape().element_type(), F32); + EXPECT_EQ(gte_b->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0), a); + EXPECT_EQ(crs->operand(1), b); + EXPECT_EQ(a->shape().element_type(), BF16); + EXPECT_EQ(b->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {0}).element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index b032c040e8aff49f9e0fc1ff9a1c1e79ea4bb77f..14c54ddd135af024327f63418b410da1ed3c4fd4 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -152,44 +152,64 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( std::vector operand_types(crs->operand_count()); std::vector output_types(crs->operand_count()); - bool has_f32 = false; - bool has_bf16 = false; - bool has_bf16_output = false; + int64 f32_count = 0; + int64 bf16_count = 0; + bool has_unsupported_bf16_operand = false; + bool has_unsupported_bf16_output = false; for (int64 i = 0; i < crs->operand_count(); ++i) { operand_types[i] = crs->operand(i)->shape().element_type(); output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); - if (operand_types[i] == F32 || output_types[i] == F32) { - has_f32 = true; + if (operand_types[i] == F32) { + f32_count += 1; } else if (operand_types[i] == BF16) { - has_bf16 = true; + bf16_count += 1; + if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + has_unsupported_bf16_operand = true; + } } - if (output_types[i] == BF16) { - has_bf16 = true; - has_bf16_output = true; + if (output_types[i] == F32) { + f32_count += 1; + } else if (output_types[i] == BF16) { + bf16_count += 1; + if (!bfloat16_support_->SupportsBF16Output(*crs)) { + has_unsupported_bf16_output = true; + } } } - for (int64 i = 0; i < crs->operand_count(); ++i) { + if (bf16_count == 0) { + return Status::OK(); + } + + auto should_convert_operand = [&](int64 i) { if (operand_types[i] != BF16) { - continue; + return false; } - if (bfloat16_support_->SupportsBF16Operand(*crs, i) && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { - continue; + if (!bfloat16_support_->SupportsBF16Operand(*crs, i)) { + return true; } - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); - has_f32 = true; - } + if (bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return false; + } + return has_unsupported_bf16_operand || has_unsupported_bf16_output || + f32_count > 0; + }; - if (!has_bf16_output) { - return Status::OK(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (should_convert_operand(i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + f32_count += 1; + bf16_count -= 1; + } } - if (bfloat16_support_->SupportsBF16Output(*crs) && - (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + if (!has_unsupported_bf16_output && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || f32_count == 0 || + bf16_count == 0)) { return Status::OK(); } + std::vector materialized_users = crs->users(); std::vector output_elements(crs->operand_count()); auto original_shape = crs->shape(); for (int64 i = 0; i < crs->operand_count(); ++i) { @@ -209,7 +229,6 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( auto tuple = computation_->AddInstruction( HloInstruction::CreateTuple(output_elements)); - std::vector materialized_users = crs->users(); // Use the crs' shape temporarily, in order to pass checks in // ReplaceUseWith. *tuple->mutable_shape() = crs->shape(); @@ -221,41 +240,37 @@ Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( } Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { - std::vector bf16_operands; - std::vector f32_operands; - bool has_f32 = false; - bool has_bf16 = false; + int f32_count = 0; + int bf16_count = 1; for (int64 i = 0; i < hlo->operand_count(); ++i) { if (hlo->operand(i)->shape().element_type() == F32) { - f32_operands.push_back(i); - has_f32 = true; + f32_count += 1; } else if (hlo->operand(i)->shape().element_type() == BF16) { - bf16_operands.push_back(i); - has_bf16 = true; + bf16_count += 1; } } if (hlo->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (hlo->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; } std::vector bf16_called_comps; for (auto* comp : hlo->called_computations()) { bool comp_has_bf16 = false; if (comp->root_instruction()->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (comp->root_instruction()->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; comp_has_bf16 = true; } for (auto* param : comp->parameter_instructions()) { if (param->shape().element_type() == F32) { - has_f32 = true; + f32_count += 1; } else if (param->shape().element_type() == BF16) { - has_bf16 = true; + bf16_count += 1; comp_has_bf16 = true; } } @@ -264,54 +279,69 @@ Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { } } - if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && - has_f32) { - // Resolve unsupported mixed precision. - // - // See if we can change everything to BF16. - if (hlo->called_computations().empty() && - hlo->shape().element_type() == BF16) { - bool can_use_bf16 = true; - for (int i : f32_operands) { - if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, - i) && - bfloat16_support_->SupportsBF16Operand(*hlo, i)) { - continue; - } - can_use_bf16 = false; - break; - } - if (can_use_bf16) { - for (int i : f32_operands) { - TF_RETURN_IF_ERROR( - InsertConvertBeforeOperand(hlo, i, BF16, computation_)); - } - return Status::OK(); - } - } - if (hlo->shape().element_type() == BF16) { - TF_RETURN_IF_ERROR( - ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); - } - for (int i : bf16_operands) { - TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); - } - return ConvertCalledComputations(hlo, bf16_called_comps); - } - - for (int i : bf16_operands) { - if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Resolve unsupported BF16 operands. + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Operand(*hlo, i)) { TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + bf16_count -= 1; + f32_count += 1; } } + // Resolve unsupported BF16 output. if (hlo->shape().element_type() == BF16 && !bfloat16_support_->SupportsBF16Output(*hlo)) { TF_RETURN_IF_ERROR( ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + bf16_count -= 1; + f32_count += 1; } - return Status::OK(); + // Resolve unsupported mixed precision after resolving unsupported BF16 + // operands and output, because the numbers of BF16 operands/output and F32 + // operands/output may have changed. + if (bfloat16_support_->SupportsMixedPrecisions(*hlo) || bf16_count == 0 || + f32_count == 0) { + return Status::OK(); + } + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16) { + continue; + } + if ((bfloat16_support_->EffectiveOperandPrecisionIsBF16(*hlo, i) || + bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i)) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + return ConvertCalledComputations(hlo, bf16_called_comps); } Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 66c3085842c4afe7ffc4d5891883e4cce9389d45..1afaefd9df9c5771fb9e134ae9050f3abb00ea4a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -41,13 +42,17 @@ class TestBFloat16Support : public BFloat16Support { hlo.opcode() == HloOpcode::kGetTupleElement) { return true; } + if (hlo.opcode() == HloOpcode::kDot) { + // Test that only the first operand of kDot supports BF16. + return operand_index == 0; + } return false; } bool SupportsBF16Output(const HloInstruction& hlo) const override { if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kSubtract || - hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kDot || hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { return true; } @@ -70,6 +75,10 @@ class BFloat16NormalizationTest : public HloTestBase { BFloat16Normalization normalization(&bfloat16_support_); StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); + + HloVerifier verifier(/*allow_mixed_precision=*/true); + EXPECT_IS_OK(verifier.Run(module).status()); + return result.ValueOrDie(); } }; @@ -166,7 +175,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); - Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {}); auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( @@ -245,4 +254,34 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); } +// Tests that the normalization should not cause unsupported mixed precision due +// to resolving unsupported BF16 operand. +TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(dot->shape().element_type(), F32); + EXPECT_EQ(dot->operand(0)->shape().element_type(), F32); + EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(dot->operand(1)->shape().element_type(), F32); + EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConvert); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc new file mode 100644 index 0000000000000000000000000000000000000000..c26d2feef584faeff013a602409cdd58c2d44a5a --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -0,0 +1,710 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +BFloat16Propagation::BFloat16Propagation( + const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + +void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( + HloInstruction* fusion) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) { + return; + } + + // We are depending on the fusion node itself having already been analyzed + // for whether it can output BF16 and this has been adjusted in the output + // shape, and now we're looking to update the interior of the fusion node to + // match the new output shape, as well as recursively process the whole fusion + // node even if the output shape was not modified. + auto root = fusion->fused_instructions_computation()->root_instruction(); + + // Adjust root's element types according to the fusion's output shape. + ShapeUtil::ForEachMutableSubshape( + root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() != F32) { + return; + } + if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() == + BF16) { + subshape->set_element_type(BF16); + changed_ = true; + VLOG(2) << "Fused root " << root->ToString() << " at shape index " + << index << " changed to BF16 precision for fusion " + << fusion->ToString(); + } + }); + + // Propagate BF16 in the fusion computation. + auto insts = + fusion->fused_instructions_computation()->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_mutation_pass_.insert( + fusion->fused_instructions_computation()); +} + +void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( + HloInstruction* while_hlo) { + CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile); + + // We are depending on the while node itself having already been analyzed for + // whether it can output BF16 and this has been adjusted in the output shape, + // and now we're looking to update the body and condition computations to + // match the new output shape, as well as recursively process the whole while + // node even if the output shape was not modified. + HloComputation* body = while_hlo->while_body(); + auto body_root = body->root_instruction(); + HloComputation* condition = while_hlo->while_condition(); + + ShapeUtil::ForEachMutableSubshape( + body_root->mutable_shape(), + [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() != F32) { + return; + } + if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() == + BF16) { + subshape->set_element_type(BF16); + changed_ = true; + VLOG(2) << "While body root " << body_root->ToString() + << " at shape index " << index + << " changed to BF16 precision for while " + << while_hlo->ToString(); + } + }); + + auto body_insts = body->MakeInstructionPostOrder(); + for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend(); + ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_mutation_pass_.insert(body); + + auto condition_insts = condition->MakeInstructionPostOrder(); + for (auto inst_it = condition_insts.rbegin(); + inst_it != condition_insts.rend(); ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } + computations_visited_in_mutation_pass_.insert(condition); +} + +bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const { + auto value_set = dataflow_->GetValueSet(&hlo, index); + for (const HloValue* value : value_set.values()) { + if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { + return false; + } + if (value->shape().element_type() == BF16) { + continue; + } + for (const HloUse& use : value->uses()) { + if (!ContainsKey(instructions_visited_in_mutation_pass_, + use.instruction)) { + // We don't know yet whether use.instruction will consume BF16 since it + // hasn't been visited. Although we visit instructions in reverse + // topological order, this is still possible because there may be + // unvisited instruction that alias the same buffer. In this case, we + // aggressively skip this use, and if this causes inconsistency (e.g., + // one use is in BF16 but another use is in F32), it will be resolved at + // the end of the BFloat16Propagation pass. + continue; + } + // Any visited user that can accept BF16 has already been updated if + // necessary, e.g., the output has been changed to BF16 if it propagates + // precision, or a called computation's parameters have been changed to + // BF16 for fusions or whiles. + if (use.instruction->opcode() == HloOpcode::kFusion) { + const auto* fused_parameter = + use.instruction->fused_parameter(use.operand_number); + if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index) + .element_type() != BF16) { + return false; + } + continue; + } else if (use.instruction->opcode() == HloOpcode::kWhile) { + const auto* cond_parameter = + use.instruction->while_condition()->parameter_instruction( + use.operand_number); + if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index) + .element_type() != BF16) { + return false; + } + const auto* body_parameter = + use.instruction->while_body()->parameter_instruction( + use.operand_number); + if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index) + .element_type() != BF16) { + return false; + } + continue; + } + if (bfloat16_support_->EffectiveOperandPrecisionIsBF16( + *use.instruction, use.operand_number)) { + continue; + } + // If the op propagates precision and it outputs a BF16, then it's OK to + // supply BF16 also as the input. In the backward mutation pass, the users + // shapes should have already been processed. + PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + user_output_type = ShapeUtil::GetSubshape( + ShapeUtil::GetSubshape(use.instruction->shape(), + {use.operand_number}), + use.operand_index) + .element_type(); + } else { + user_output_type = use.instruction->shape().element_type(); + } + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( + *use.instruction, use.operand_number) && + user_output_type == BF16) { + continue; + } + return false; + } + } + return true; +} + +void BFloat16Propagation::DetermineAndMutateInstructionPrecision( + HloInstruction* hlo, bool skip_parameters) { + // We handle any fusion computation or while body/condition after the + // instruction is handled, because we need to know the output shape of a + // fusion or while before propagating inside its computations. + bool postpone_processing_called_computations = false; + auto cleaner = tensorflow::gtl::MakeCleanup( + [this, hlo, &postpone_processing_called_computations] { + if (!postpone_processing_called_computations) { + if (hlo->opcode() == HloOpcode::kFusion) { + DetermineAndMutateFusionComputationPrecision(hlo); + } else if (hlo->opcode() == HloOpcode::kWhile) { + DetermineAndMutateWhileComputationsPrecision(hlo); + } + } + instructions_visited_in_mutation_pass_.insert(hlo); + }); + + if (hlo->opcode() == HloOpcode::kWhile && + (caller_counts_[hlo->while_condition()] > 1 || + caller_counts_[hlo->while_body()] > 1)) { + postpone_processing_called_computations = true; + return; + } + + // Do not change precision for instructions related to entry and exit of a + // computation, and control flow, because this pass might break the interfaces + // or assumptions for them. + if (hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kSend || // + hlo->opcode() == HloOpcode::kSendDone || // + hlo->opcode() == HloOpcode::kRecv || // + hlo->opcode() == HloOpcode::kRecvDone || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kConditional || // + (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { + return; + } + + // Prevent root instructions from having their output modified by recording + // all F32 output values as needing to stay as F32. + CHECK(hlo->parent() != nullptr); + if (hlo == hlo->parent()->root_instruction()) { + if (!hlo->parent()->IsFusionComputation()) { + ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape, + const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + // Since we use HloValues from the dataflow analysis, this can also + // affect HLO instructions beyond the root, e.g., if the root is a + // Tuple HLO, then its operands are also affected. + values_that_must_be_kept_as_f32_.insert(value); + } + }); + } + return; + } + + if (!ContainsKey(consider_using_bfloat16_, hlo)) { + return; + } + + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { + return; + } + + ShapeUtil::ForEachMutableSubshape( + hlo->mutable_shape(), + [hlo, this](Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() == F32 && + AllUsersConsumeBF16(*hlo, index)) { + subshape->set_element_type(BF16); + changed_ = true; + VLOG(2) << "HloInstruction output at shape index " << index + << " changed to BF16 precision: " << hlo->ToString(); + } + }); +} + +bool BFloat16Propagation::InstructionIsCandidateForBF16Output( + HloInstruction* hlo) { + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && + hlo->opcode() != HloOpcode::kTuple && + hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->shape().element_type() != BF16) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) || + !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) { + return false; + } + } + } + return true; +} + +void BFloat16Propagation::AdjustCalledComputationParameters( + HloInstruction* hlo) { + auto adjust_computation = + [this, hlo](HloComputation* computation, + tensorflow::gtl::ArraySlice operands) { + // Adjust parameters. + CHECK_EQ(operands.size(), computation->num_parameters()); + for (int64 i = 0; i < operands.size(); ++i) { + auto parameter = computation->parameter_instruction(i); + ShapeUtil::ForEachMutableSubshape( + parameter->mutable_shape(), + [this, i, hlo, &operands, parameter](Shape* subshape, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { + return; + } + PrimitiveType operand_type = + ShapeUtil::GetSubshape(operands[i]->shape(), index) + .element_type(); + if (subshape->element_type() == operand_type) { + return; + } + CHECK(operand_type == F32 || operand_type == BF16); + subshape->set_element_type(operand_type); + changed_ = true; + VLOG(2) << "Called computation parameter " + << parameter->ToString() << " at shape index " << index + << " adjusted to match operand in HLO " + << hlo->ToString(); + }); + } + }; + + switch (hlo->opcode()) { + case HloOpcode::kFusion: + adjust_computation(hlo->fused_instructions_computation(), + hlo->operands()); + break; + case HloOpcode::kWhile: + adjust_computation(hlo->while_condition(), hlo->operands()); + adjust_computation(hlo->while_body(), hlo->operands()); + break; + default: + break; + } +} + +void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { + auto adjust_computation = [this, hlo](HloComputation* computation, + const Shape& output_shape) { + // Adjust root. + HloInstruction* root = computation->root_instruction(); + ShapeUtil::ForEachMutableSubshape( + root->mutable_shape(), [this, hlo, root, &output_shape]( + Shape* subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { + return; + } + const PrimitiveType output_type = + ShapeUtil::GetSubshape(output_shape, index).element_type(); + if (subshape->element_type() == output_type) { + return; + } + CHECK(output_type == F32 || output_type == BF16); + subshape->set_element_type(output_type); + // It's possible that output_type is F32, but the root instruction's + // type is BF16; e.g., a fusion node's output was changed to BF16 + // initially but then adjusted back to F32, and the fusion computation + // is now being adjusted after the fusion node. + if (output_type == F32) { + for (const auto* value : + dataflow_->GetValueSet(root, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order so that called computation will be + // processed later. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the + // correctness of the adjustment for HLOs that will be + // processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + changed_ = true; + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index + << " adjusted to match output shape of " << hlo->ToString(); + }); + }; + + switch (hlo->opcode()) { + case HloOpcode::kFusion: + adjust_computation(hlo->fused_instructions_computation(), hlo->shape()); + break; + case HloOpcode::kWhile: + adjust_computation(hlo->while_condition(), hlo->shape()); + adjust_computation(hlo->while_body(), hlo->shape()); + break; + default: + break; + } +} + +bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited_computations) { + bool parameter_changed = false; + auto insts = computation->MakeInstructionPostOrder(); + // Do the adjustment on each instruction in the computation in reverse + // topological order. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + auto adjust_hlo_output = [this, hlo, ¶meter_changed]( + Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() != F32 && subshape->element_type() != BF16) { + return; + } + PrimitiveType type = BF16; + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + if (value->shape().element_type() == BF16) { + continue; + } + CHECK_EQ(value->shape().element_type(), F32); + type = F32; + break; + } + // It's possible that a user has been changed from BF16 to F32 + // during this final adjustment pass, so we need to check + // AllUsersConsumeBF16() again. + if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { + type = F32; + } + if (type == F32) { + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the correctness + // of the adjustment for HLOs that will be processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + if (type != subshape->element_type()) { + subshape->set_element_type(type); + VLOG(2) << "HloInstruction output at shape index " << index + << " adjusted to " << *subshape << ": " << hlo->ToString(); + if (hlo->opcode() == HloOpcode::kParameter) { + parameter_changed = true; + } + } + }; + ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + if (hlo->opcode() == HloOpcode::kWhile) { + // We need to run on the while body and condition repeatedly until a fixed + // point is reached, i.e., the parameters do not change any more. We may + // need more than one iteration because the while input and output alias + // each other, so changing one input parameter requires changing the + // corresponding output element and thus may transitively require changing + // another input parameter. A fixed point will be reached because the + // parameters can only be changed from BF16 to F32, not the other way + // around. + tensorflow::gtl::FlatSet visited_in_while; + while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), + &visited_in_while) || + ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), + &visited_in_while)) { + visited_in_while.clear(); + ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), + adjust_hlo_output); + AdjustCalledComputationRoot(hlo); + } + visited_computations->insert(visited_in_while.begin(), + visited_in_while.end()); + } + } + // Now adjust parameters of called computations. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + AdjustCalledComputationParameters(*inst_it); + } + return parameter_changed; +} + +Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( + HloModule* module) { + std::list computations_topological_order = + module->MakeComputationPostOrder(); + tensorflow::gtl::FlatSet resolved; + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + if (ContainsKey(resolved, *comp_it)) { + continue; + } + ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved); + } + + // We could have changed a fusion computation's root shape to have a different + // precision than the fusion node's output, if the fusion root does not + // define a buffer (e.g., a tuple). Now we add conversions after such fusion + // roots to make them match the fusion output. If the fusion output is a + // (possibly nested) tuple, we first create get-tuple-elements, then convert + // the unmatching leaf nodes, and finally create a new tuple as the fusion + // computation's root. If tuples and get-tuple-elements are created, we will + // run tuple simplifier and dead code elimination at the end (dead code is not + // allowed in fusion computation). E.g., + // + // (1) (2) (3) + // a b a b a b + // |\ | |\ | |\ | + // \ add -> |add -> | add + // \ | \ | convert | + // tuple tuple \ | + // / \ tuple + // gte gte + // | | + // convert | + // \ / + // tuple + // (1) a is F32 but tuple is BF16 + // (2) after adding conversion + // (3) after tuple simplifier and DCE. + bool needs_tuple_simplifier = false; + for (auto computation : computations_topological_order) { + auto insts = computation->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + if (hlo->opcode() != HloOpcode::kFusion) { + continue; + } + auto fusion_computation = hlo->fused_instructions_computation(); + auto fusion_root = fusion_computation->root_instruction(); + if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) { + continue; + } + ShapeTree converted_outputs(hlo->shape()); + // Iterate through nodes in the shape tree in pre-order and initialize + // each non-root node with a corresponding get-tuple-element. For a leaf + // node, if its shape does not match the fusion output, create a + // conversion node to overwrite the node value. + for (auto it = converted_outputs.begin(); it != converted_outputs.end(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (output_index.empty()) { + output = fusion_root; + } else { + ShapeIndex parent_index = output_index; + parent_index.pop_back(); + output = fusion_computation->AddInstruction( + HloInstruction::CreateGetTupleElement( + subshape, converted_outputs.element(parent_index), + output_index.back())); + } + if (ShapeUtil::IsTuple(subshape)) { + continue; + } + if (!ShapeUtil::Compatible( + subshape, + ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) { + output = fusion_computation->AddInstruction( + HloInstruction::CreateConvert(subshape, output)); + } + } + // Iterate through nodes in the shape tree in reverse pre-order and create + // a tuple instruction for each non-leaf node where the elements are the + // values of its child nodes. + for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape& subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (!ShapeUtil::IsTuple(subshape)) { + continue; + } + std::vector elements( + ShapeUtil::TupleElementCount(subshape)); + ShapeIndex child_index = output_index; + for (int64 i = 0; i < elements.size(); ++i) { + child_index.push_back(i); + elements[i] = converted_outputs.element(child_index); + child_index.pop_back(); + } + output = fusion_computation->AddInstruction( + HloInstruction::CreateTuple(elements)); + } + fusion_computation->set_root_instruction(converted_outputs.element({})); + needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); + } + } + + // We may have converted some constants from F32 to BF16, so adjust the + // constant literals in such cases. We do this here instead of when the + // constant node's is changed because 1) the HloInstruction interface does not + // allow resetting the literal so we have to create a new kConstant + // instruction to replace the old one, which invalidates dataflow analysis, + // and 2) it's possible that a kConstant's output gets changed to BF16 at the + // beginning but later on adjusted back to F32, so converting literals here + // can avoid repeated conversions. + // + // TODO(b/73833576): Consider resetting literal in HloInstruction. + bool needs_dce = needs_tuple_simplifier; + for (auto computation : computations_topological_order) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConstant) { + continue; + } + if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { + TF_ASSIGN_OR_RETURN( + auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape(), + /*round_f32_to_bf16=*/true)); + auto new_constant = computation->AddInstruction( + HloInstruction::CreateConstant(std::move(converted_literal))); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); + needs_dce = true; + } + } + } + + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + } + if (needs_dce) { + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { + for (auto computation : module->computations()) { + for (auto hlo : computation->MakeInstructionPostOrder()) { + if (hlo->opcode() != HloOpcode::kConvert) { + continue; + } + auto source = hlo->mutable_operand(0); + if (!ShapeUtil::Equal(source->shape(), hlo->shape())) { + continue; + } + const bool is_root = hlo == computation->root_instruction(); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source)); + if (is_root) { + computation->set_root_instruction(source); + } + TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo)); + } + } + return Status::OK(); +} + +// The algorithm first does a forward pass (parameters to root) to determine a +// set of instructions to consider using bfloat16, then does a backward pass to +// determine the precisions of those instructions according to the need of +// their users. +StatusOr BFloat16Propagation::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); + + std::list computations_topological_order = + module->MakeComputationPostOrder(); + // The first step is a forward pass (parameters to root), where we determine + // the potential candidate instructions to use bfloat16 in the outputs that + // are not likely to cause overhead from extra explicit conversions. This is + // done forwardly because we determine whether an HLO is a candidate partially + // based on whether its operands are candidates. + for (auto computation : computations_topological_order) { + for (auto inst : computation->MakeInstructionPostOrder()) { + if (InstructionIsCandidateForBF16Output(inst)) { + consider_using_bfloat16_.insert(inst); + } + } + } + + // The second step is a backward pass (root to parameters), where we modify + // the precisions of the instructions identified in the first step when + // feasible. This is done backwardly because we determine the precision of an + // HLO's output based on how it is later used. + // + // The precision of an instruction is determined by its users, so we do the + // propagation in reverse topological order. + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + if ((*comp_it)->IsFusionComputation()) { + // Fusion computations are handled when visiting the fusion instruction. + continue; + } + auto insts = (*comp_it)->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, + /*skip_parameters=*/true); + } + } + + if (!changed_) { + return false; + } + + // It's possible that an instruction does not define a buffer, but the + // defining instruction's shape has changed. So we need to adjust the output + // shapes of instructions according to the HLO values they refer to. + TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + + // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> + // BF16), so we remove them now. + TF_RETURN_IF_ERROR(RemoveNoopConversions(module)); + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h new file mode 100644 index 0000000000000000000000000000000000000000..1744e9db90aeff269daa91eb68a1d61bb0fc3035 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// HLO pass which reduces the precision of some HLO instructions to BF16 +// according to the backend-specific BFloat16Support rule provided by the +// caller. +// +// This pass can be used to reduce instruction precision without affecting the +// numerical accuracy of the module, i.e., the final output of the module would +// be bitwise identical to that without this pass; this is possible if the +// backend already reduces precision to BF16 on some HLO instructions. +// +// This pass will not modify the signature of a computation, unless it is a +// fusion computation or its only caller is a while. +// +// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs, +// which has two issues: +// +// 1) It does not guarantee to respect the passed-in BFloat16Support +// specification in terms of mixed precision, so the backend may not support an +// HLO that has mixed precision produced by this pass. To address this issue, +// run BFloat16Normalization with the same BFloat16Support after this pass. +// +// 2) In general, mixed precision may break the assumptions of some other HLO +// passes even if the specific backend supports the individual HLOs. Such +// assumptions include that there are no HLOs using mixed precision, or that the +// precision of an HLO's output is determined by its inputs. It should be used +// at the end of the HLO optimization pipeline but before +// BFloat16ConversionFolding. If other passes are needed after this pass, run +// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this +// pass. +class BFloat16Propagation : public HloPassInterface { + public: + explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); + + ~BFloat16Propagation() override = default; + + tensorflow::StringPiece name() const override { + return "bfloat16-propagation"; + } + + // Runs the pass on the given module. Returns whether the module was changed + // (precision reductions were added). + StatusOr Run(HloModule* module) override; + + private: + // *************************** + // Function called and state produced by the forward analysis pass (from + // parameters to root) that determines the candidate HLOs to use BF16 outputs. + + // Determines whether we should consider changing the precision of the given + // instruction in the forward pass. + bool InstructionIsCandidateForBF16Output(HloInstruction* hlo); + + // The set of instructions to consider using bfloat16, computed in the forward + // pass. + tensorflow::gtl::FlatSet consider_using_bfloat16_; + + // *************************** + // Functions called and state produced by the backward mutation pass (from + // root to parameters). + + // Determines the precision for the given instruction in the mutation pass. + void DetermineAndMutateInstructionPrecision(HloInstruction* hlo, + bool skip_parameters); + + // Special handling in the mutation pass for fusion computations. + // + // Precondition: hlo->opcode() == kFusion + void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion); + + // Special handling in the mutation pass for while computations. + // + // Precondition: hlo->opcode() == kWhile + void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo); + + // The set of HloInstructions that have been visited in the mutation pass. + tensorflow::gtl::FlatSet + instructions_visited_in_mutation_pass_; + + // The set of HloComputations that have been visited in the mutation pass. + tensorflow::gtl::FlatSet + computations_visited_in_mutation_pass_; + + // *************************** + // Functions called by the final inconsistency resolving pass. + + // Adjusts the output shapes of HloInstructions such that if two + // HloInstructions have aliasing buffers in their outputs, they must have the + // same precision. + Status ResolveInconsistencyOfAliasingBuffers(HloModule* module); + + // Resolves inconsistency of aliasing buffers for the given computation, and + // recursively runs on a while instruction's condition and body until a fixed + // point is reached. + bool ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + tensorflow::gtl::FlatSet* visited_computations); + + // Makes the parameters of called computations match how they are called by + // the given HLO. + void AdjustCalledComputationParameters(HloInstruction* hlo); + + // Makes the root instructions of called computations match how they are used + // by the given HLO. + void AdjustCalledComputationRoot(HloInstruction* hlo); + + // *************************** + // Removes no-op conversions (same source and target shapes) that can be + // produced this pass. + Status RemoveNoopConversions(HloModule* module); + + // *************************** + // Functions called and state used by two or more passes. + + // Returns whether all uses of the given HloInstruction can consume BF16 + // input. + bool AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const; + + // The set of F32 HLO values that must be kept in F32. + tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; + + // Mapping from each HloComputation to the number of callers to it in the + // module. Populated at the beginning of this pass. + tensorflow::gtl::FlatMap caller_counts_; + + const BFloat16Support* bfloat16_support_; + std::unique_ptr dataflow_; + + bool changed_ = false; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..88f83014164ff726a11e45e762b9c082cf12720d --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -0,0 +1,660 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// A class specifying the BF16 support used to test the propagation pass. It +// specifies that BF16 and mixed precision are supported in all HloInstructions, +// and that kDot reduces its operands precision to BF16. +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return true; + } + + bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const override { + return hlo.opcode() == HloOpcode::kDot; + } +}; + +class BFloat16PropagationTest : public HloTestBase { + protected: + // Runs the propagation pass on the given module, and returns whether the + // module is changed after this pass. + bool PropagatePrecision(HloModule* module) { + TestBFloat16Support bfloat16_support; + BFloat16Propagation propagation(&bfloat16_support); + StatusOr result = propagation.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } + + // Returns whether the given HloInstruction's output element type is BF16 or + // the only use of it is converting to BF16. + bool OutputsBF16(const HloInstruction* inst) { + if (inst->shape().element_type() == BF16) { + return true; + } + return inst->user_count() == 1 && + inst->users()[0]->opcode() == HloOpcode::kConvert && + inst->users()[0]->shape().element_type() == BF16; + } +}; + +// Tests that BF16 can propagate through select over non-tuple buffers, but not +// through add where reducing operand precision can affect the result. +TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* sel = builder.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(sel)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); + EXPECT_FALSE(OutputsBF16(c)); +} + +// Tests that if a constant is converted to BF16 then its literal must also be +// converted. +TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Array2D array_a(4, 4); + array_a.FillUnique(1.0f); + Array2D array_b(4, 4); + array_b.FillUnique(10.0f); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromArray(array_a))); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromArray(array_b))); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(dot->operand(0))); + EXPECT_TRUE(OutputsBF16(dot->operand(1))); + EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); + EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); + LiteralTestUtil::ExpectEqual( + dot->operand(0)->literal(), + *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); + LiteralTestUtil::ExpectEqual( + dot->operand(1)->literal(), + *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); +} + +// Tests that BF16 can be propagated through nested tuples. +TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); + + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose})); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1)); + HloInstruction* rhs = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + add0->shape(), + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple0->shape(), tuple1, 0)), + 0)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + + HloInstruction* output_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), output_tuple); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(add2)); +} + +// Tests that even if an instruction does not define a buffer in its output, its +// shape must match the defining instruction. +TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); + + HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); + + // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_TRUE(OutputsBF16(lhs)); + // rhs is a get-tuple-element, which does not define a buffer, but its shape + // should also be adjusted accordingly. + EXPECT_TRUE(OutputsBF16(rhs)); +} + +// Tests that a non-fusion computation's root should not be changed. +TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_FALSE(OutputsBF16(add)); +} + +// Tests that BF16 is propagated properly through fused computations. +TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f0 = HloComputation::Builder("fusion0"); + HloInstruction* a_f0 = + builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f0 = + builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* tuple_f0 = + builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0})); + auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build()); + auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f0)); + + auto builder_f1 = HloComputation::Builder("fusion1"); + HloInstruction* p_f1 = builder_f1.AddInstruction( + HloInstruction::CreateParameter(0, tuple_f0->shape(), "param")); + HloInstruction* a_f1 = builder_f1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); + HloInstruction* b_f1 = builder_f1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); + HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); + auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( + dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), fusion1); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(a_f0)); + EXPECT_TRUE(OutputsBF16(b_f0)); + EXPECT_TRUE(OutputsBF16(a_f1)); + EXPECT_TRUE(OutputsBF16(b_f1)); +} + +// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion +// outputs are only used by a dot, and 3) one element of the tuple is used by +// an add in the fusion computation, then the propagation pass should create a +// convert in the fusion computation to keep the add's operand in F32 but change +// the fusion output to BF16. E.g., the following fusion computation +// (F32, F32) fusion_computation(F32 a, F32 b) +// = tuple(F32 a, F32 add(F32 a, F32 b)) +// will be changed to +// (BF16, BF16) fusion_computation(F32 a, F32 b) +// = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) +TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion0"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* tuple_f = + builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f)); + + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(gte0)); + EXPECT_TRUE(OutputsBF16(gte1)); + EXPECT_FALSE(OutputsBF16(a_f)); + EXPECT_FALSE(OutputsBF16(b_f)); + EXPECT_TRUE(OutputsBF16(add_f)); + auto new_fusion_root = comp_f->root_instruction(); + EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(new_fusion_root->operand(1), add_f); + EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0))); +} + +// A select over tuples does not define the leaf buffers, so the types in +// on_true and on_false must match, so that as long as one of them is F32, the +// other must be F32 as well. +TEST_F(BFloat16PropagationTest, SelectOverTuples) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(PRED, {}), "pred")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({param, add0})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); + HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( + tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, sel, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, sel, 1)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(gte0)); + EXPECT_FALSE(OutputsBF16(gte1)); + EXPECT_TRUE(OutputsBF16(xpose)); +} + +// Tests that BF16 is propagated properly through while computations. +TEST_F(BFloat16PropagationTest, PropagateThroughWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, tuple->shape(), "cond_param")); + auto cond_lhs = builder_cond.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond_param, 0)); + auto cond_rhs = builder_cond.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond_param, 1)); + // This add should prevent RHS from using BF16 + auto cond_add_rhs = builder_cond.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); + auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, tuple->shape(), "body_param")); + auto body_lhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + auto body_rhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 1)); + auto body_dot = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + builder_body.AddInstruction( + HloInstruction::CreateTuple({body_dot, body_rhs})); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(tuple->shape(), cond, body, tuple)); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(lhs)); + EXPECT_FALSE(OutputsBF16(rhs)); + EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_lhs)); + EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_TRUE(OutputsBF16(cond_lhs)); + EXPECT_FALSE(OutputsBF16(cond_rhs)); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(add1)); +} + +// Tests that BF16 is not propagated through multiple whiles that invoke the +// same computation as long as one while prevents the propagation. +TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* add3 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({add2, add3})); + + // Condition computation for the first while. + auto builder_cond0 = HloComputation::Builder("cond0"); + auto cond0_param = builder_cond0.AddInstruction( + HloInstruction::CreateParameter(0, tuple0->shape(), "cond0_param")); + auto cond0_lhs = builder_cond0.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond0_param, 0)); + auto cond0_rhs = builder_cond0.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond0_param, 1)); + // This add should prevent RHS from using BF16 + auto cond0_add_rhs = + builder_cond0.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); + auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + builder_cond0.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); + + // Condition computation for the second while. + auto builder_cond1 = HloComputation::Builder("cond1"); + auto cond1_param = builder_cond1.AddInstruction( + HloInstruction::CreateParameter(0, tuple1->shape(), "cond1_param")); + auto cond1_lhs = builder_cond1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond1_param, 0)); + auto cond1_rhs = builder_cond1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, cond1_param, 1)); + // This add should prevent LHS from using BF16 + auto cond1_add_lhs = + builder_cond1.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); + auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + builder_cond1.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); + + // Body computation shared by both whiles. + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, tuple0->shape(), "body_param")); + auto body_lhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + auto body_rhs = builder_body.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 1)); + auto body_dot = builder_body.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + builder_body.AddInstruction( + HloInstruction::CreateTuple({body_dot, body_rhs})); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple0->shape(), cond0, body, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); + + auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_FALSE(OutputsBF16(body_dot)); + EXPECT_FALSE(OutputsBF16(body_rhs)); + EXPECT_FALSE(OutputsBF16(body_lhs)); + EXPECT_FALSE(OutputsBF16(cond0_lhs)); + EXPECT_FALSE(OutputsBF16(cond0_rhs)); + EXPECT_FALSE(OutputsBF16(cond1_lhs)); + EXPECT_FALSE(OutputsBF16(cond1_rhs)); + EXPECT_TRUE(OutputsBF16(cond0_add_rhs)); + EXPECT_TRUE(OutputsBF16(cond1_add_lhs)); + EXPECT_EQ(computation->root_instruction(), dot); +} + +// Tests that if this pass turns an F32 -> BF16 conversion into a no-op (BF16 -> +// BF16 conversion), then it will remove that conversion. +TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "param")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, param, param)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 1)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte1)); + HloInstruction* add2 = builder.AddInstruction(HloInstruction::CreateBinary( + bf16_shape, HloOpcode::kAdd, convert0, convert1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), add2); + EXPECT_EQ(add2->operand(0), gte0); + EXPECT_EQ(add2->operand(1), gte1); + EXPECT_EQ(gte0->shape().element_type(), BF16); + EXPECT_EQ(gte1->shape().element_type(), BF16); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 3fd9e24601f27633c8063e4574c7c4f91f30dcff..07b4b14b5ec1bdbc01345091105df69368b0b2fb 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -79,6 +79,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kBroadcast: case HloOpcode::kClamp: case HloOpcode::kConcatenate: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h index 29f662d22b4e5486662a1387407d41e0fd2ed1b3..82c2745f444e4f9c544c78cb36dafc11f678518a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.h +++ b/tensorflow/compiler/xla/service/bfloat16_support.h @@ -39,7 +39,7 @@ class BFloat16Support { // precisions (BF16 and F32). virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const; - // Returns whether the given HLO inherits its BF16 operand precision at the + // Returns whether the given HLO preserves its BF16 operand precision at the // given index, so even if the output is F32, elements in the output that // depend on the BF16 operand will still have BF16 effective precision even if // they have F32 format. Similarly, this also means if the output is BF16 then diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index b1e693da9d5af4babe619b8796007f2da318f6a8..dbe45e932cdeed00e959355d5b3199d2e858148f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -48,6 +48,183 @@ using ::tensorflow::strings::HumanReadableNumBytes; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrAppend; +namespace { + +template +string ColocatedBufferSetsToString(const T& container, const char* title) { + string result; + StrAppend(&result, title, "\n"); + for (const auto& it : container) { + StrAppend(&result, "\t", it->ToString(), "\n"); + } + return result; +} + +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations) { + // Create a worklist of computations paired with whether the allocation must + // be thread-local. + std::deque> worklist; + worklist.push_back(std::make_pair(module->entry_computation(), + /*is_thread_local*/ false)); + + // Sets for quickly checking membership. Computations are returned in vectors + // for stable iteration. + FlatSet thread_local_set; + FlatSet global_set; + + while (!worklist.empty()) { + auto worklist_front = worklist.front(); + worklist.pop_front(); + const HloComputation* computation = worklist_front.first; + bool is_thread_local = worklist_front.second; + bool in_thread_local_set = thread_local_set.count(computation) > 0; + bool in_global_set = global_set.count(computation) > 0; + + // If the computation has already been added to the respective set, then + // nothing to do. + if ((is_thread_local && in_thread_local_set) || + (!is_thread_local && in_global_set)) { + continue; + } + + // If the computation has already been added to the other set this is an + // error condition because the global call to the computation (eg, + // while/call) may return a reference to one of the thread-local buffers to + // the calling computation which will become a dangling reference when the + // thread-local is deallocated with the call return. + if ((is_thread_local && in_global_set) || + (!is_thread_local && in_thread_local_set)) { + return InvalidArgument( + "computation %s has conflicting allocation requirements (global " + "and thread-local)", + computation->name().c_str()); + } + + if (is_thread_local) { + thread_local_set.insert(computation); + } else { + global_set.insert(computation); + } + + for (auto* instruction : computation->instructions()) { + for (HloComputation* subcomputation : + instruction->called_computations()) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kWhile: + // Call and while must be called from a computation with global + // allocations as they may return references to buffers inside the + // called computation which cannot be thread-local. + if (is_thread_local) { + return InvalidArgument( + "computation %s cannot contain call/while op because it " + "requires thread-local buffer allocations", + computation->name().c_str()); + } + worklist.push_back(std::make_pair(subcomputation, + false)); // Not thread local. + break; + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + // Map/reduce etc computations are always thread-local. + worklist.push_back(std::make_pair(subcomputation, + true)); // Thread local. + break; + default: + return InternalError( + "Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode()).c_str()); + } + } + } + } + + // Add the computations to the vectors in post order. + for (auto* computation : module->MakeComputationPostOrder()) { + if (thread_local_set.count(computation) > 0) { + thread_local_computations->push_back(computation); + } else if (global_set.count(computation) > 0) { + global_computations->push_back(computation); + } + // If the computation is not reachable from the entry computation, then it + // will not appear in either thread_local_set or global_set. We don't bother + // assigning buffers for these. + } + return Status::OK(); +} + +// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. + std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace + size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); @@ -115,6 +292,112 @@ BufferAllocationProto BufferAllocation::ToProto() const { return proto; } +std::pair> +BufferAllocation::ComputePeakMemoryLogicalBuffers() const { + if (HeapTraces().empty()) { + // Just return the largest LogicalBuffer in the allocation. + const LogicalBuffer* largest_buffer = nullptr; + int64 largest_size = 0; + for (const auto& pair : assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + int64 size = pair.second.size; + if (largest_buffer == nullptr) { + largest_buffer = buffer; + largest_size = size; + continue; + } + // Tie-break with LogicalBuffer::Id so the return value is stable relative + // to changing addresses. + if (size > largest_size || + ((size == largest_size) && (largest_buffer->id() > buffer->id()))) { + largest_buffer = buffer; + largest_size = size; + } + } + CHECK(largest_buffer != nullptr) + << "No logical buffers in allocation: " << ToString(); + return {largest_size, {largest_buffer}}; + } + + // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical + // buffers in this allocation. + tensorflow::gtl::FlatMap + id_to_buffer; + tensorflow::gtl::FlatMap buffer_sizes; + for (const auto& pair : assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + const OffsetSize& offset_size = pair.second; + id_to_buffer[buffer->id()] = buffer; + buffer_sizes[buffer] = offset_size.size; + } + + // Returns how much the given event increases the total size of live + // buffers. Can be negative. + auto memory_delta = [this, &id_to_buffer, &buffer_sizes]( + const HeapSimulatorTrace::Event& event) -> int64 { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + const int64 buffer_size = buffer_sizes.at(buffer); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + return buffer_size; + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Sharing a buffer does not change the live set size for the purposes of + // the heap simulator. Even though the shared-with buffer may be smaller, + // the entire allocation remains live. + return 0; + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + return -1 * buffer_size; + } + LOG(FATAL) << "Unknown event kind: " << event.kind(); + }; + + int64 total_max_live_size = 0; + std::vector live_buffers_vector; + for (const HeapSimulatorTrace& heap_trace : HeapTraces()) { + // First compute the size of the maximal live set. + int64 max_live_size = 0; + int64 live_size = 0; + for (const auto& event : heap_trace.events()) { + live_size += memory_delta(event); + if (max_live_size < live_size) { + max_live_size = live_size; + } + } + + // Next gather the set of logical buffers live at the earliest point of + // maximal live set size. + tensorflow::gtl::FlatSet live_buffers; + live_size = 0; + for (const auto& event : heap_trace.events()) { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + InsertOrDie(&live_buffers, buffer); + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Nothing to do. + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + CHECK(ContainsKey(live_buffers, buffer)); + live_buffers.erase(buffer); + } + + live_size += memory_delta(event); + if (live_size == max_live_size) { + break; + } + } + CHECK_EQ(live_size, max_live_size); + total_max_live_size += max_live_size; + + live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), + live_buffers.end()); + } + + // Stabily sort the live buffers. + std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + return {total_max_live_size, live_buffers_vector}; +} + string BufferAllocation::ToString() const { string output; Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); @@ -348,6 +631,7 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // Combines allocations of temporary buffers of the same color into one big // BufferAllocation. void BufferAssignment::CombineTempAllocations() { + VLOG(1) << "CombineTempAllocations()"; FlatMap combined_allocation_map; @@ -369,11 +653,16 @@ void BufferAssignment::CombineTempAllocations() { if (combined_it == combined_allocation_map.end()) { // We have found the first temp allocation of this color. Collect // the other temp allocations of the same color into it. + VLOG(1) << "Combined temp allocation for color " << color + << " is: " << temp_allocation; combined_allocation_map.emplace(color, temp_allocation); continue; } auto* combined_allocation = &combined_it->second; + VLOG(1) << "Combined allocation absorbing temp allocation: " + << temp_allocation; + // Each temp allocation is placed end-to-end, accounting for alignment. // The offset of each buffer in the combined allocation is computed from // the base offset of the allocation. @@ -387,6 +676,10 @@ void BufferAssignment::CombineTempAllocations() { const int64 size = buffer_offset_size.second.size; combined_allocation->AddAssignment(*buffer, base + offset, size); } + if (!temp_allocation.HeapTraces().empty()) { + CHECK_EQ(temp_allocation.HeapTraces().size(), 1); + combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front()); + } } // Replace all existing temporary allocations with the new combined // allocations. @@ -516,123 +809,13 @@ BufferAssignmentProto BufferAssignment::ToProto() const { for (const BufferAllocation& allocation : Allocations()) { BufferAllocationProto proto_allocation = allocation.ToProto(); proto.add_buffer_allocations()->Swap(&proto_allocation); - } - for (const HeapSimulatorTrace& trace : heap_simulator_traces_) { - *proto.add_heap_simulator_traces() = trace; - } - return proto; -} - -namespace { - -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). -Status GatherComputationsByAllocationType( - const HloModule* module, - std::vector* thread_local_computations, - std::vector* global_computations) { - // Create a worklist of computations paired with whether the allocation must - // be thread-local. - std::deque> worklist; - worklist.push_back(std::make_pair(module->entry_computation(), - /*is_thread_local*/ false)); - - // Sets for quickly checking membership. Computations are returned in vectors - // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; - - while (!worklist.empty()) { - auto worklist_front = worklist.front(); - worklist.pop_front(); - const HloComputation* computation = worklist_front.first; - bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; - - // If the computation has already been added to the respective set, then - // nothing to do. - if ((is_thread_local && in_thread_local_set) || - (!is_thread_local && in_global_set)) { - continue; + for (const HeapSimulatorTrace& heap_trace : allocation.HeapTraces()) { + *proto.add_heap_simulator_traces() = heap_trace; } - - // If the computation has already been added to the other set this is an - // error condition because the global call to the computation (eg, - // while/call) may return a reference to one of the thread-local buffers to - // the calling computation which will become a dangling reference when the - // thread-local is deallocated with the call return. - if ((is_thread_local && in_global_set) || - (!is_thread_local && in_thread_local_set)) { - return InvalidArgument( - "computation %s has conflicting allocation requirements (global " - "and thread-local)", - computation->name().c_str()); - } - - if (is_thread_local) { - thread_local_set.insert(computation); - } else { - global_set.insert(computation); - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* subcomputation : - instruction->called_computations()) { - switch (instruction->opcode()) { - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kWhile: - // Call and while must be called from a computation with global - // allocations as they may return references to buffers inside the - // called computation which cannot be thread-local. - if (is_thread_local) { - return InvalidArgument( - "computation %s cannot contain call/while op because it " - "requires thread-local buffer allocations", - computation->name().c_str()); - } - worklist.push_back(std::make_pair(subcomputation, - false)); // Not thread local. - break; - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kFusion: - // Map/reduce etc computations are always thread-local. - worklist.push_back(std::make_pair(subcomputation, - true)); // Thread local. - break; - default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); - } - } - } - } - - // Add the computations to the vectors in post order. - for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { - thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { - global_computations->push_back(computation); - } - // If the computation is not reachable from the entry computation, then it - // will not appear in either thread_local_set or global_set. We don't bother - // assigning buffers for these. } - return Status::OK(); + return proto; } -} // namespace - /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, @@ -1064,7 +1247,8 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } - assignment->heap_simulator_traces_.push_back(result.debug_trace); + VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString(); + allocation->AddHeapTrace(result.debug_trace); } // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining @@ -1085,7 +1269,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( if (colocated_set.empty()) { return; } - + VLOG(5) << ColocatedBufferSetsToString(colocated_set, + "Adding colocated buffer set"); // Find existing sets that overlap with at least one buffer from the // colocated_set. The resulting 'overlap_set_indices' will have at most // colocated_buffer_sets->size() entries, and will be in increasing order. @@ -1093,6 +1278,10 @@ void BufferAssigner::AddSetToColocatedBufferSets( for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + VLOG(5) << "Found overlap with existing set on buffer " + << buffer->ToString() << "\n" + << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], + "Overlapping set"); overlap_set_indices.push_back(index); break; } @@ -1104,6 +1293,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( colocated_buffer_sets->emplace_back(); colocated_buffer_sets->back().insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << "No overlap found, new group created"; return; } @@ -1115,6 +1305,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( first->insert(overlap_set.begin(), overlap_set.end()); } first->insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << ColocatedBufferSetsToString( + *first, "Result of the colocated buffer set merging"); // Remove overlap sets that we just merged. The offset accounts for the fact // that as elements are erased, the indices need to be adjusted. Keep in mind @@ -1125,67 +1317,6 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } -namespace { - -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. - std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. - int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - std::vector BufferAssigner::MergeColocatedBufferSets( const std::vector& colocated_buffer_sets, @@ -1208,26 +1339,35 @@ BufferAssigner::MergeColocatedBufferSets( auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness, &buffer_size, &is_entry_parameter](int64 i, int64 j) { - for (auto& buffer_a : colocated_buffer_sets[i]) { - for (auto& buffer_b : colocated_buffer_sets[j]) { - // Do not merge if the set includes live outs or entry parameters. - if ((buffer_liveness.MaybeLiveOut(*buffer_a) && - is_entry_parameter(*buffer_b)) || - (buffer_liveness.MaybeLiveOut(*buffer_b) && - is_entry_parameter(*buffer_a))) { + // Do not merge if one of the sets includes live outs or entry parameters. + for (int64 key : {i, j}) { + for (auto& buffer : colocated_buffer_sets[key]) { + if (buffer_liveness.MaybeLiveOut(*buffer) || + is_entry_parameter(*buffer)) { return true; } - // Do not merge if the buffers interfere with each other. + } + } + + // Colocated sets satisfy the invariant that all buffers within a set have + // the same size. That means we need to check whether the size is the same + // between the two sets, but also that it's enough to look at just one + // buffer within each set. + if (buffer_size(**colocated_buffer_sets[i].begin()) != + buffer_size(**colocated_buffer_sets[j].begin())) { + return true; + } + + // Do not merge if some pair of buffers interferes with each other. + for (auto& buffer_a : colocated_buffer_sets[i]) { + for (auto& buffer_b : colocated_buffer_sets[j]) { if (buffer_a->id() != buffer_b->id() && buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) { return true; } - // Do not merge if the buffer sizes are different. - if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) { - return true; - } } } + return false; }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 6b7fd0014d103ef0617afcc5cb3f663554a01aa4..3086d0e2ca0026547134285b8ceb357390fc7ece 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -192,6 +192,37 @@ class BufferAllocation { !is_thread_local(); } + // Add a heap trace which was used to assign slices to logical buffers in this + // allocation. A single BufferAllocation may include multiple heap traces + // in the case of the temporary block where there is a heap trace per + // computation. + void AddHeapTrace(const HeapSimulatorTrace& heap_trace) { + heap_traces_.push_back(heap_trace); + } + + // Return the set of heap traces used to assign slices to logical buffers in + // this allocation. + const std::vector HeapTraces() const { + return heap_traces_; + } + + // Compute and return the LogicalBuffers which are live at the point of peak + // memory usage for the given allocation. The point of peak memory usage is + // the point at which the total size of all live logical buffers is + // maximal. If peak memory is reached at multiple points, the set of logical + // buffers live at the earliest maximal point is returned. The vector is + // stabily asserted by LogicalBuffer::Index. + // + // The return value is a pair of total size of the logical buffers at peak, + // and the buffers themselves. + std::pair> + ComputePeakMemoryLogicalBuffers() const; + + // Get the number of bytes lost to fragmentation. This is equal to the + // difference between the size of the allocation and the size of the maximal + // live set. + int64 fragmentation_bytes() const { return fragmentation_bytes_; } + bool operator==(const BufferAllocation& other) const { return index_ == other.index_; } @@ -257,6 +288,9 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. tensorflow::gtl::FlatMap assigned_buffers_; + + int64 fragmentation_bytes_ = 0; + std::vector heap_traces_; }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. @@ -441,7 +475,6 @@ class BufferAssignment { LogicalBuffer::AlignmentFunction color_alignment_; Stats stats_; - std::vector heap_simulator_traces_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index cd73654b8f666c4b96c000235cc3ad2cd0a46c17..513a8785bbd52b0a3bfa3642bbfc62b1035ffb17 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -37,14 +37,16 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" namespace xla { - namespace { +using ::testing::UnorderedElementsAre; + // DFS visitor that collects the instructions referenced by a computation // without descending into nested computations, i.e., only from the operands. class InstructionListVisitor : public DfsHloVisitorWithDefault { @@ -101,6 +103,22 @@ class BufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr RunBufferAssignmentWithInstructionSequence( + HloModule* module, + tensorflow::gtl::ArraySlice instruction_sequence, + int64 alignment = 1) { + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[module->entry_computation()] = + std::vector(instruction_sequence.begin(), + instruction_sequence.end()); + return BufferAssigner::Run( + module, + xla::MakeUnique(module, module_sequence), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }) + .ConsumeValueOrDie(); + } + // Builds an x+1.0 computation to use in a Map. std::unique_ptr BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -1370,7 +1388,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto element_slices = assignment->GetAllSlices(select, /*index=*/{0}); EXPECT_EQ(2, element_slices.size()); EXPECT_THAT(element_slices, - ::testing::UnorderedElementsAre( + UnorderedElementsAre( assignment->GetUniqueSlice(tuple_param0, /*index=*/{0}) .ConsumeValueOrDie(), assignment->GetUniqueSlice(tuple_param1, /*index=*/{0}) @@ -1473,6 +1491,98 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { } } +TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); + builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + + // Trivially, the set of peak memory logical buffer(s) of an allocation with a + // single logical buffer should be exactly the logical buffer in that + // allocation. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + int64 peak_size; + std::vector peak_buffers; + + std::tie(peak_size, peak_buffers) = + mul_buffer.ComputePeakMemoryLogicalBuffers(); + EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(f32vec100_)); + ASSERT_EQ(peak_buffers.size(), 1); + EXPECT_EQ(peak_buffers[0]->instruction(), mul); +} + +TEST_F(BufferAssignmentTest, PeakBuffers) { + // Compute the peak liveness buffers of the following sequence: + // + // %param = ... + // %log = log(%param) + // %rev = reverse(%log) + // %neg = neg(%param) + // %concat = concat(%rev, %neg) + // ROOT %root = slice(concat) + // + // In the temporary block, the set of live buffers at peak memory use should + // be {%rev, %neg, %concat}. This occurs right at the concat itself. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec100_, "")); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); + auto rev = builder.AddInstruction( + HloInstruction::CreateReverse(f32vec100_, log, {0})); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param)); + const Shape concat_shape = ShapeUtil::MakeShape(F32, {200}); + auto concat = builder.AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0)); + // Make the root tiny so no interior nodes can share its buffer. + auto root = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignmentWithInstructionSequence( + module.get(), {param, log, rev, neg, concat, root}); + + // The temporary buffer should hold the 4 interior instructions. + const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); + EXPECT_FALSE(buffer.IsInputOrOutput()); + EXPECT_TRUE(buffer.IsPreallocatedTempBuffer()); + ASSERT_EQ(buffer.assigned_buffers().size(), 4); + + int64 peak_size; + std::vector peak_buffers; + std::tie(peak_size, peak_buffers) = buffer.ComputePeakMemoryLogicalBuffers(); + + // The peak live set should be concat and its inputs. + EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {400}))); + ASSERT_EQ(peak_buffers.size(), 3); + std::vector peak_instructions; + for (const LogicalBuffer* logical_buffer : peak_buffers) { + peak_instructions.push_back(logical_buffer->instruction()); + } + EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1587,6 +1697,81 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); } +// Tests that two colocated buffer sets are not merged if an entry parameter +// buffer belongs to either of the colocation sets (b/73267882). +// +// %param --> %while.0 --> %mul --> %while.1 --> %broadcast +// +// %while.0 body just forwards the init value, so the loop carried variable +// remains the constant, whereas %while.1 changes the loop carried variable. +TEST_F(WhileBufferAssignmentTest, ColocatedBufferWithEntryParameter) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + const char* module_str = R"( +HloModule test_module + +%cond.v0 { + %param = s32[] parameter(0) + ROOT %constant = pred[] constant(true) +} + +%cond.v1 { + %param.0 = s32[] parameter(0) + ROOT %constant.0 = pred[] constant(true) +} + +%body.v0 { + ROOT %param.1 = s32[] parameter(0) +} + +%body.v1 { + %param.2 = s32[] parameter(0) + ROOT add = s32[] add(%param.2, %param.2) +} + +ENTRY %test_module { + %param.3 = s32[] parameter(0) + %while.0 = s32[] while(%param.3), condition=%cond.v0, body=%body.v0 + %mul = s32[] multiply(%while.0, %while.0) + %while.1 = s32[] while(%mul), condition=%cond.v1, body=%body.v1 + ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Get the instructions in the module. + const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* param = + module->entry_computation()->parameter_instruction(0); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + const HloInstruction* while1 = bcast->operand(0); + ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); + const HloInstruction* while0 = while1->operand(0)->operand(0); + ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); + + // Run buffer assignment. + auto assignment = RunBufferAssignment(module.get()); + TF_ASSERT_OK_AND_ASSIGN(auto slice_param, + assignment->GetUniqueSlice(param, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + + // The parameter slice is part of the while0's colocation set (init value), + // but not merged into the while1's colocation set. + EXPECT_EQ(slice_param, slice_while0); + EXPECT_NE(slice_param, slice_while1); +} + // Tests that the colocated buffers for while instructions are properly assigned // during buffer assignment such that the result tuple elements are not assigned // to the same buffer. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c..a8053d15e124319c5c898f0034b9aaa95a007a89 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -51,8 +51,8 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { return out; } -CallContext GetInstructionCallContext(const HloInstruction* instruction) { - switch (instruction->opcode()) { +CallContext GetInstructionCallContext(HloOpcode opcode) { + switch (opcode) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kWhile: @@ -101,7 +101,7 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) { void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) { CHECK_EQ(instruction->parent(), computation()); - const CallContext context = GetInstructionCallContext(instruction); + const CallContext context = GetInstructionCallContext(instruction->opcode()); if (!instruction->called_computations().empty()) { CHECK(context == CallContext::kSequential || context == CallContext::kParallel); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 688c4085dfb4f47d3e08a4abee5e7b645f595b11..97d3811508adee1bf2d0942bcc69e3e34a41c8c3 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -53,7 +53,7 @@ enum class CallContext { string CallContextToString(CallContext context); std::ostream& operator<<(std::ostream& out, const CallContext& context); -CallContext GetInstructionCallContext(const HloInstruction* instruction); +CallContext GetInstructionCallContext(HloOpcode opcode); // Represents an HLO instruction which calls one or more computations. class CallSite { diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index dab73596e1639eed62151197048ee8d29570b20a..c83da9eddc8f8b156dd9acfc99b393bf844575da 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -72,8 +72,7 @@ CompileOnlyService::CompileAheadOfTime( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); - // TODO(b/63773457): Track DebugOptions in AotCompilationOptions. - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + const DebugOptions& debug_options = options.debug_options(); // Dump computation proto state if flag is set. const string& directory_path = debug_options.xla_dump_computations_to(); @@ -101,7 +100,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, *user_computation)); + &execution_options, user_computation)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1..0392d4af48a040c4a648f7bf9bf21a62ce03a990 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -86,4 +86,7 @@ Compiler::GetPlatformCompilers() { return compilers->at(platform->id()).get(); } +AotCompilationOptions::AotCompilationOptions() + : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 74fd24edf88d44b2dfdc87556b0af43987e69e08..b4b53ae2ed425a48de5bcb6ba5c37b5d37e1f371 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -79,11 +79,15 @@ class AotCompilationOptions { device_allocator_ = device_allocator; } + const DebugOptions& debug_options() const { return debug_options_; } + DebugOptions* mutable_debug_options() { return &debug_options_; } + protected: - AotCompilationOptions() = default; + AotCompilationOptions(); private: DeviceMemoryAllocator* device_allocator_ = nullptr; + DebugOptions debug_options_; }; // Abstract compiler interface that is subclassed for compilation on a @@ -123,7 +127,7 @@ class Compiler { // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses - // prior to calling this method because the some HLO passes are required for + // prior to calling this method because some HLO passes are required for // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..f35de080853f7ec986565cb2df1050946ac3f244 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +// Tries to replace a conditional with a call operation of the corresponding +// computation. If the given conditional has a constant predicate, tries to +// replace it with a call to its true/false computation as appropirate and then +// inline that computation. +// +// Returns true if it made a change to the graph. +static StatusOr TryRemoveConditional(HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + // Do not remove conditionals that contain side-effecting instructions or + // have control predecessors/successors in either true/false computation. + if (!conditional->parent()->IsRemovable(conditional) || + conditional->HasSideEffect()) { + VLOG(2) << "Not attempting to remove conditional as it is not removable or " + "has side effect: " + << conditional->ToShortString(); + return false; + } + + if (conditional->operand(0)->opcode() != HloOpcode::kConstant) { + VLOG(2) << "Not attempting to remove conditional as its predicate is not a " + "compile-time constant: " + << conditional->ToShortString(); + return false; + } + + auto computation = conditional->parent(); + HloInstruction* call_op; + if (conditional->operand(0)->literal().Get({})) { + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(1)}, + conditional->true_computation())); + } else { + call_op = computation->AddInstruction(HloInstruction::CreateCall( + conditional->shape(), {conditional->mutable_operand(2)}, + conditional->false_computation())); + } + + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); + + return true; +} + +StatusOr ConditionalSimplifier::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "ConditionalSimplifier::Run(), before:\n" + module->ToString()); + bool changed = false; + + // Gather all the conditional ops in our module. We do this ahead of time so + // we don't have to worry about mutating the lists of computations or + // instructions as we iterate. + std::vector conditional_ops; + for (auto* comp : module->computations()) { + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } + } + } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); + changed |= result; + } + + XLA_VLOG_LINES(3, + "ConditionalSimplifier::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..063261e26d06e21a297e8e3c405898a17221b7ca --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// HLO pass that removes kConditional with a constant predicate, replacing them +// with their true or false computation as appropriate. +class ConditionalSimplifier : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "simplify-conditional"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..868348547d9f5cbdc7576c7fc0697d72c3a3e557 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ConditionalSimplifierTest : public HloVerifiedTestBase { + public: + // Makes a computation that contains a conditional with constant predicate. + HloComputation* MakeConditional(HloModule* module); +}; + +HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) { + HloComputation::Builder builder(TestName()); + + // true_computation returns param+1. + HloComputation* true_computation; + { + HloComputation::Builder true_computation_builder(TestName() + + ".true_computation"); + auto param = + true_computation_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param")); + auto one = true_computation_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + + true_computation_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one)); + + true_computation = + module->AddEmbeddedComputation(true_computation_builder.Build()); + } + + // false_computation returns param+42. + HloComputation* false_computation; + { + HloComputation::Builder false_computation_builder(TestName() + + ".false_computation"); + auto param = false_computation_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), + "param")); + auto forty_two = false_computation_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + false_computation_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two)); + false_computation = + module->AddEmbeddedComputation(false_computation_builder.Build()); + } + + auto false_instrn = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto false_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "false_param")); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + + builder.AddInstruction(HloInstruction::CreateConditional( + ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation, + false_param, false_computation)); + + return module->AddEntryComputation(builder.Build()); +} + +TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) { + HloComputation* computation = MakeConditional(&module()); + ASSERT_TRUE(ConditionalSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Constant())); +} + +TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) { + HloComputation* computation = MakeConditional(&module()); + + auto* true_op = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + TF_ASSERT_OK( + true_op->AddControlDependencyTo(computation->root_instruction())); + + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + + auto* true_computation = conditional->true_computation(); + auto* send = true_computation->AddInstruction(HloInstruction::CreateSend( + true_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))), + /*channel_id=*/0)); + true_computation->AddInstruction(HloInstruction::CreateSendDone(send)); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + + auto* true_computation = conditional->true_computation(); + auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv( + ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); + true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv)); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { + HloComputation* computation = MakeConditional(&module()); + auto* conditional = computation->root_instruction(); + ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional); + auto* false_computation = conditional->false_computation(); + false_computation->AddInstruction( + HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cd983bc03e993caed883916de01d75dffdbc4bab..40519ecc799c8f0343294ad88009820dbd8535e9 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -58,6 +58,46 @@ bool ValueIsReadOnly(const HloValue& value) { return IsConstantValue(value) || IsEntryParameterValue(value); } +// Data structure describing the action which should be taken on parts of a +// computation buffers, with respect to the adding of special case copies. +struct SpecialCaseCopyPolicy { + // Insert a copy if the same buffer is found at multiple indices within the + // output tuple. + bool copy_root_replicated_buffers = false; + // If true, insert a copy if a buffer coming from a constant or a parameter + // is found wihtin the output tuple. + bool copy_parameters_and_constants = false; +}; + +SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, + HloModule* module, + HloComputation* computation) { + SpecialCaseCopyPolicy policy; + if (computation == module->entry_computation()) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + for (const CallSite& site : node.caller_callsites()) { + // The AddCopiesForConditional() already adds copies, but the copy remover + // removes them, so we re-add them by returning the policy here. But really + // the copy remover should not be removing them. + if (site.instruction()->opcode() == HloOpcode::kConditional) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + } + return policy; +} + +bool ShouldCopyRootValue(const HloValue& value, + const SpecialCaseCopyPolicy& policy) { + if (policy.copy_parameters_and_constants) { + return IsConstantValue(value) || + value.defining_instruction()->opcode() == HloOpcode::kParameter; + } + return false; +} + // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in // 'indices_to_copy'. Add control edges from the respective kCopy instructions // in deep copy of 'from' to the respective kCopy instruction in the deep copy @@ -282,6 +322,29 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// We add copies for all the indices of the true and false computaiton roots, +// in order to resolve interference. We later rely on the CopyRemover to drop +// the unnecessary ones. +Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, + HloInstruction* conditional) { + VLOG(2) << "Adding copies for kConditional instruction " + << conditional->name(); + TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); + + for (HloComputation* computation : + {conditional->true_computation(), conditional->false_computation()}) { + HloInstruction* root = computation->root_instruction(); + std::vector users = root->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + computation->DeepCopyInstruction(root)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); + } + computation->set_root_instruction(deep_copy); + } + return Status::OK(); +} + // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -309,6 +372,9 @@ Status AddCopiesToResolveInterference(HloModule* module) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR( + AddCopiesForConditional(*alias_analysis, instruction)); } } } @@ -557,6 +623,7 @@ class CopyRemover { auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) { + VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value; if (LiveRangeBefore(a, b)) { VLOG(2) << " Live range of " << a.value->ToShortString() << " is before " << b.value->ToShortString(); @@ -571,7 +638,7 @@ class CopyRemover { VLOG(3) << copy->name() << " copies value " << src->value->ToShortString(); VLOG(3) << "Source buffer values: " << ValueListToString(src); - VLOG(3) << "Dest buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(dest); // A kCopy instruction copies an HLO value from a source buffer and // defines an HLO value in a destination buffer. Most generally, the @@ -729,7 +796,8 @@ class CopyRemover { // has a different operand (the operand of the elided copy). for (const HloUse* copy_use : copy_value_node->uses) { operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy) { + if (copy_use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, copy_use->instruction)) { copy_map_.at(copy_use->instruction).src = operand_node; } } @@ -746,16 +814,16 @@ class CopyRemover { // updated as copies are removed. bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { if (a.uses.empty()) { - VLOG(2) << "Empty uses"; + VLOG(2) << "Empty uses for " << *a.value; return ordering_.IsDefinedBefore(*a.value, *b.value); } for (const HloUse* use : a.uses) { - VLOG(2) << "use: " << *use; - VLOG(2) << "is before:" << *b.value; + VLOG(2) << "Checking use " << *use << " against " << *b.value; if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Not before"; + VLOG(2) << "Use " << *use << " is NOT before " << *b.value; return false; } + VLOG(2) << "Use " << *use << " is before " << *b.value; } return true; } @@ -891,7 +959,6 @@ Status RemoveUnnecessaryCopies( CopyRemover copy_remover(*alias_analysis, ordering, module); XLA_VLOG_LINES(3, copy_remover.ToString()); - tensorflow::gtl::FlatSet existing_copies; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCopy && @@ -900,7 +967,6 @@ Status RemoveUnnecessaryCopies( } } } - return Status::OK(); } @@ -920,7 +986,7 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { // Identify which shape indices of which instructions need to be copied. Store // these results in 'instructions_to_copy'. - std::unordered_map> instructions_to_copy; + HloInstructionMap> instructions_to_copy; auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, const ShapeIndex& index) { auto it = instructions_to_copy.find(instruction); @@ -956,7 +1022,8 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { } TF_RET_CHECK(node.context() == CallContext::kSequential); - const bool is_entry = computation == module->entry_computation(); + SpecialCaseCopyPolicy policy = + GetSpecialCaseCopyPolicy(node, module, computation); HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. @@ -969,27 +1036,26 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { for (const HloBuffer* buffer : buffers_at_index) { buffer_seen_before |= !seen.insert(buffer).second; } - if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { - VLOG(2) << "Index " << index << " of root of computation " + if (buffers_at_index.size() > 1 || + (buffer_seen_before && policy.copy_root_replicated_buffers)) { + VLOG(2) << "Index " << index << " of computation " << computation->name() << " (" << root->name() << ") has ambiguous or non-distinct buffer. Copying."; add_index_to_copy(root, index); } }); - // For entry instructions, mark any parameter or constant values. - if (is_entry) { - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (ValueIsReadOnly(*value)) { - VLOG(2) << "Root of entry computation (" << root->name() - << ") has constant or entry parameter value at index " - << index << ". Copying."; - add_index_to_copy(root, index); - } + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ShouldCopyRootValue(*value, policy)) { + VLOG(2) << "Root of (" << root->name() << ") of computation(" + << computation->name() + << ") has constant or parameter value at index " << index + << ". Copying."; + add_index_to_copy(root, index); } } } @@ -1011,7 +1077,6 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { instruction->parent()->set_root_instruction(deep_copy); } } - return Status::OK(); } @@ -1155,7 +1220,7 @@ bool IsWhileBody(const HloComputation* computation, HloModule* module) { std::unique_ptr call_graph = CallGraph::Build(module); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); bool changed = false; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index c13a0b1cdf0b5be0b69db98b2b9587f30ca4c304..246b80286189286dd29a306dd0bda495df9dad3e 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -18,6 +18,10 @@ load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") +load( + "//third_party/mkl:build_defs.bzl", + "if_mkl", +) # Filegroup used to collect source files for dependency checking. filegroup( @@ -105,9 +109,11 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -163,10 +169,12 @@ cc_library( ":disassembler", ":external_constant_pool", ":orc_jit_memory_mapper", + ":runtime_fp16", ":runtime_conv2d", ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", + ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", @@ -182,6 +190,20 @@ cc_library( ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) +cc_library( + name = "runtime_fp16", + srcs = [ + "runtime_fp16.cc", + ], + hdrs = [ + "runtime_fp16.h", + ], + copts = runtime_copts(), + deps = [ + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "cpu_executable", srcs = ["cpu_executable.cc"], @@ -499,7 +521,6 @@ cc_library( cc_library( name = "runtime_matvec", - srcs = ["runtime_matvec.cc"], hdrs = ["runtime_matvec.h"], copts = runtime_copts(), deps = [ @@ -522,6 +543,22 @@ cc_library( ], ) +cc_library( + name = "runtime_matmul_mkl", + srcs = ["runtime_matmul_mkl.cc"], + hdrs = ["runtime_matmul_mkl.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ] + if_mkl([ + "//third_party/mkl:intel_binary_blob", + "@mkl_dnn", + ]), +) + cc_library( name = "runtime_single_threaded_conv2d", srcs = [ @@ -568,10 +605,12 @@ cc_library( tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], + shard_count = 10, tags = ["optonly"], deps = [ ":cpu_runtime", ":runtime_matmul", + ":runtime_matmul_mkl", ":runtime_single_threaded_matmul", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:types", @@ -654,6 +693,22 @@ cc_library( ], ) +tf_cc_test( + name = "ir_emission_utils_test", + srcs = ["ir_emission_utils_test.cc"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "cpu_layout_assignment", srcs = ["cpu_layout_assignment.cc"], @@ -756,6 +811,31 @@ cc_library( ], ) +tf_cc_test( + name = "parallel_task_assignment_test", + srcs = ["parallel_task_assignment_test.cc"], + deps = [ + ":cpu_executable", + ":parallel_task_assignment", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], @@ -859,17 +939,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index ed290fcdf8bb69f1bbad57fa5a0926376bc9405a..6a7eb85e3baec3517b8f3ddef6a8dcfae9c9e614 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -25,11 +25,11 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/MC/MCContext.h" #include "llvm/Object/ObjectFile.h" +#include "llvm/Support/SmallVectorMemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" @@ -93,8 +93,8 @@ class FilteredPassManager : public llvm::legacy::PassManager { }; } // anonymous namespace -llvm::object::OwningBinary CompilerFunctor:: -operator()(llvm::Module& module) const { +std::unique_ptr CompilerFunctor::operator()( + llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); FilteredFunctionPassManager function_passes(&module, disable_expensive_passes_); @@ -157,27 +157,8 @@ operator()(llvm::Module& module) const { codegen_passes.run(module); // Construct ObjectFile from machine code buffer. - std::unique_ptr memory_buffer( - new llvm::ObjectMemoryBuffer(std::move(stream_buffer))); - llvm::Expected> - object_file_or_error = llvm::object::ObjectFile::createObjectFile( - memory_buffer->getMemBufferRef()); - CHECK(object_file_or_error); - - std::unique_ptr object_file = - std::move(object_file_or_error.get()); - if (VLOG_IS_ON(2)) { - StatusOr disassembly_status = - disassembler_->DisassembleObjectFile(*object_file); - if (disassembly_status.ok()) { - auto result = disassembly_status.ValueOrDie(); - XLA_VLOG_LINES(2, result.text); - VLOG(2) << "compiled code size: " << result.code_size_bytes << " bytes"; - } - } - - return llvm::object::OwningBinary( - std::move(object_file), std::move(memory_buffer)); + return std::unique_ptr( + new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 1a8283a702223a7414c1ffcd99c1ac42c04ac068..c38b896c5019b48fd2a16a51abd59e12ebdb29eb 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -47,7 +47,7 @@ class CompilerFunctor { post_optimization_hook_(post_optimization_hook) {} // Compile a Module to an ObjectFile. - llvm::object::OwningBinary operator()( + std::unique_ptr operator()( llvm::Module& module) const; // NOLINT private: diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index f9cc9651846cca7bd6ab7e9e61590cec4e2400da..e43777c5e5e8afcf08e1e334c8847f6b94d0d047 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -47,6 +47,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" @@ -66,6 +67,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -260,6 +262,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); + pipeline.AddPass(); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -275,6 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); } pipeline.AddPass( [](const HloInstruction& dot, @@ -314,7 +318,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Note this is not run for AOT because it would bring in thread pool // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). - // TODO(29630486) Support multi-threaded AOT. + // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); } @@ -889,11 +893,10 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, module->config().debug_options().xla_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); - llvm::object::OwningBinary object_file = + std::unique_ptr object_file = compiler_functor(llvm_module); - llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); - ObjectFileData object_file_data(object_file_data_ref.begin(), - object_file_data_ref.end()); + ObjectFileData object_file_data(object_file->getBufferStart(), + object_file->getBufferEnd()); BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 802d0a6fb46890b31d14b1fbf3b2e7d6520caccc..c053703c3524a47ee1de9681c1b986edbf109430 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -63,7 +63,7 @@ CpuExecutable::CpuExecutable( assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. - llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); + llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name); // We expect to find the symbol provided with entry_function_name; otherwise // this is an internal error. CHECK(sym) << "Symbol " << entry_function_name << " not found."; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 267b89a10b3c038dc2048f0ad5b5b343c88ef0f9..d3502b3a03e27c8f90ed74c4d826dfab1c4e8b75 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -71,11 +71,6 @@ class CpuExecutable : public Executable { ir_module_string_ = ir_module_string; } - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on CPU executable. - return Unimplemented("Equality test on CPU executable is not implemented."); - } - static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 482e04052d5a914eab0e5bff2c7a83f3b698052f..0fc5a746bbbc7685ff5d4647111a750e7d7b1c19 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -30,7 +30,6 @@ bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // - hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 595c3f55b321f47e2312b93e0c238c7637495d77..6ed1cd31b18f6360bdd7fd41bd5be2e657b310a5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -77,7 +77,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { EXPECT_THAT(computation->root_instruction(), op::Fusion()); } -TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { +TEST_F(InstructionFusionTest, DotOperationNoFusion_Bitcast) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 512, 2, 128}), "arg0")); @@ -94,8 +94,7 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(dot, computation->root_instruction()); - EXPECT_TRUE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), op::Fusion()); + EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); } TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { @@ -244,35 +243,33 @@ class OpcodeFusionTest : public InstructionFusionTest { } }; -TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) { +TEST_F(OpcodeFusionTest, Exponential_Reshape_Negate) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {1, 4}); Shape result_shape = ShapeUtil::MakeShape(F32, {4}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); - // InstructionFusion::ShouldFuse() precludes fusing a bitcast whose operand - // is a parameter, so create an operand between the parameter and bitcast. HloInstruction* exp1 = builder.AddInstruction( HloInstruction::CreateUnary(param_shape, HloOpcode::kExp, param0)); - HloInstruction* bitcast2 = builder.AddInstruction( - HloInstruction::CreateUnary(result_shape, HloOpcode::kBitcast, exp1)); + HloInstruction* reshape2 = + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, exp1)); builder.AddInstruction( - HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, bitcast2)); + HloInstruction::CreateUnary(result_shape, HloOpcode::kNegate, reshape2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); RunFusionAndCheckOpcodesWereFused( - module.get(), {HloOpcode::kNegate, HloOpcode::kBitcast, HloOpcode::kExp, + module.get(), {HloOpcode::kNegate, HloOpcode::kReshape, HloOpcode::kExp, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { +TEST_F(OpcodeFusionTest, Broadcast_Reshape_DynamicSlice_Tanh) { HloComputation::Builder builder(TestName()); Shape param_shape = ShapeUtil::MakeShape(F32, {8}); Shape starts_shape = ShapeUtil::MakeShape(F32, {2}); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {1, 8, 8}); - Shape bitcast_shape = ShapeUtil::MakeShape(F32, {8, 8}); + Shape reshape_shape = ShapeUtil::MakeShape(F32, {8, 8}); Shape dynamic_slice_shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "param")); @@ -280,11 +277,11 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { HloInstruction::CreateParameter(1, starts_shape, "starts")); HloInstruction* broadcast2 = builder.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, param0, {1})); - HloInstruction* bitcast3 = builder.AddInstruction(HloInstruction::CreateUnary( - bitcast_shape, HloOpcode::kBitcast, broadcast2)); + HloInstruction* reshape3 = builder.AddInstruction( + HloInstruction::CreateReshape(reshape_shape, broadcast2)); HloInstruction* dynamic_slice4 = builder.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, bitcast3, param1, {4, 4})); + dynamic_slice_shape, reshape3, param1, {4, 4})); builder.AddInstruction(HloInstruction::CreateUnary( dynamic_slice_shape, HloOpcode::kTanh, dynamic_slice4)); @@ -293,7 +290,7 @@ TEST_F(OpcodeFusionTest, Broadcast_Bitcast_DynamicSlice_Tanh) { RunFusionAndCheckOpcodesWereFused( module.get(), - {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kBitcast, + {HloOpcode::kTanh, HloOpcode::kDynamicSlice, HloOpcode::kReshape, HloOpcode::kBroadcast, HloOpcode::kParameter, HloOpcode::kParameter}); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 40ace963270e8cead47cc731cc326351178dff7d..872b0be1f8a8ec317bf059fd1c4d2550e2ad161a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -31,15 +31,27 @@ XfeedManager* GetXfeedManager() { return manager; } +extern const char* const kEigenMatMulF16SymbolName = + "__xla_cpu_runtime_EigenMatMulF16"; extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kMKLMatMulF32SymbolName = + "__xla_cpu_runtime_MKLMatMulF32"; +extern const char* const kMKLMatMulF64SymbolName = + "__xla_cpu_runtime_MKLMatMulF64"; +extern const char* const kMKLSingleThreadedMatMulF32SymbolName = + "__xla_cpu_runtime_MKLSingleThreadedMatMulF32"; +extern const char* const kMKLSingleThreadedMatMulF64SymbolName = + "__xla_cpu_runtime_MKLSingleThreadedMatMulF64"; extern const char* const kEigenConvF16SymbolName = "__xla_cpu_runtime_EigenConvF16"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedMatMulF16SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 2141dfe1cedd6f9674acc348152574b4fd30895b..e392e231b4c71b2e206640a47b712de70a148582 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -41,11 +41,17 @@ namespace runtime { // the actual symbol. // 2. When using ahead-of-time compilation, the linker can resolve the name // because it is a symbol in the cpu_runtime library. +extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kMKLMatMulF32SymbolName; +extern const char* const kMKLMatMulF64SymbolName; +extern const char* const kMKLSingleThreadedMatMulF32SymbolName; +extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; extern const char* const kEigenSingleThreadedConvF16SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index f385829cdf5cafbd35e083f47106734cdd5dde88..2ac950e6d93ade315808f2ca1d0bdd7bc85f53b9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" @@ -130,25 +131,23 @@ MatMulShape MatMulShapes[] = { // * transpose_lhs // * transpose_rhs // * single_threaded -using EigenMatMulTestParam = std::tuple; +using MatMulTestParam = std::tuple; -class EigenMatMulTest - : public CpuRuntimeTest, - public ::testing::WithParamInterface { +class EigenMatMulTest : public CpuRuntimeTest, + public ::testing::WithParamInterface { public: - static string Name( - const ::testing::TestParamInfo& info) { + static string Name(const ::testing::TestParamInfo& info) { MatMulShape shape = std::get<0>(info.param); bool transpose_lhs = std::get<1>(info.param); bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); return tensorflow::strings::Printf( - "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, + "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", single_threaded ? "single" : "multi"); } -}; // namespace xla +}; TEST_P(EigenMatMulTest, DoIt) { MatMulShape shape = std::get<0>(GetParam()); @@ -169,5 +168,74 @@ INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest, ::testing::Bool()), EigenMatMulTest::Name); +#ifdef INTEL_MKL +class MKLMatMulTest : public CpuRuntimeTest, + public ::testing::WithParamInterface { + public: + static string Name(const ::testing::TestParamInfo& info) { + MatMulShape shape = std::get<0>(info.param); + bool transpose_lhs = std::get<1>(info.param); + bool transpose_rhs = std::get<2>(info.param); + bool single_threaded = std::get<3>(info.param); + + return tensorflow::strings::Printf( + "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, + transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); + } +}; + +std::unique_ptr> MKLMatrixMultiply(const Array2D& a, + const Array2D& b, + bool transpose_lhs, + bool transpose_rhs, + bool single_threaded) { + CHECK_EQ(a.width(), b.height()); + int64 m = a.height(); + int64 n = b.width(); + int64 k = a.width(); + + // The MKL matmul runtime function expects the matrix to be in column major + // order and array2d is in row-major order. Create transposes of a and b. The + // 'data' buffer in the transposed array is the original array in column major + // order. + auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs); + auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs); + + // Since we're going to transpose c before returning it, swap the order of the + // dimension sizes to ensure the returned array is properly dimensioned. + auto c_transpose = MakeUnique>(n, m); + if (single_threaded) { + __xla_cpu_runtime_MKLSingleThreadedMatMulF32( + nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), + m, n, k, transpose_lhs, transpose_rhs); + } else { + __xla_cpu_runtime_MKLMatMulF32(nullptr, c_transpose->data(), + a_transpose->data(), b_transpose->data(), m, + n, k, transpose_lhs, transpose_rhs); + } + return MaybeTransposeArray2D(*c_transpose, true); +} + +TEST_P(MKLMatMulTest, DoIt) { + MatMulShape shape = std::get<0>(GetParam()); + bool transpose_lhs = std::get<1>(GetParam()); + bool transpose_rhs = std::get<2>(GetParam()); + bool single_threaded = std::get<3>(GetParam()); + + auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k); + auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n); + auto c = + MKLMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs, single_threaded); + CheckMatrixMultiply(*a, *b, *c); +} + +INSTANTIATE_TEST_CASE_P(MKLMatMulTestInstantiaion, MKLMatMulTest, + ::testing::Combine(::testing::ValuesIn(MatMulShapes), + ::testing::Bool(), ::testing::Bool(), + ::testing::Bool()), + MKLMatMulTest::Name); +#endif // INTEL_MKL + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index cfe7c9c3af0be109ac8a86753e880e2bcbceba41..29afd8ea5f9822ea9ae969ae035511a58de4888e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -715,6 +715,11 @@ tensorflow::Status DotOpEmitter::Emit() { // which performs the sum-of-products (the reduction loop) before storing // the result in the output buffer. + // This routine assumes that the dot operation is not in a parallelized + // enclosing computation. + CHECK( + dot_.parent()->root_instruction()->outer_dimension_partitions().empty()); + const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); @@ -913,22 +918,35 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // The two transpose_... parameters are actually booleans, but we use int32 // to avoid target-dependent calling convention details. - bool multi_threaded_eigen = + bool multi_threaded = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); PrimitiveType type = target_array_.GetShape().element_type(); llvm::Type* float_type; const char* fn_name; switch (type) { + case F16: + fn_name = multi_threaded + ? runtime::kEigenMatMulF16SymbolName + : runtime::kEigenSingleThreadedMatMulF16SymbolName; + float_type = ir_builder_->getHalfTy(); + break; case F32: - fn_name = multi_threaded_eigen - ? runtime::kEigenMatMulF32SymbolName - : runtime::kEigenSingleThreadedMatMulF32SymbolName; + fn_name = multi_threaded + ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName + : runtime::kEigenMatMulF32SymbolName) + : (use_mkl_dnn + ? runtime::kMKLSingleThreadedMatMulF32SymbolName + : runtime::kEigenSingleThreadedMatMulF32SymbolName); float_type = ir_builder_->getFloatTy(); break; case F64: - fn_name = multi_threaded_eigen - ? runtime::kEigenMatMulF64SymbolName - : runtime::kEigenSingleThreadedMatMulF64SymbolName; + fn_name = multi_threaded + ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName + : runtime::kEigenMatMulF64SymbolName) + : (use_mkl_dnn + ? runtime::kMKLSingleThreadedMatMulF64SymbolName + : runtime::kEigenSingleThreadedMatMulF64SymbolName); float_type = ir_builder_->getDoubleTy(); break; default: @@ -1051,7 +1069,8 @@ static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. - return output_shape.element_type() == F32 && + PrimitiveType output_primitive_type = output_shape.element_type(); + return (output_primitive_type == F32 || output_primitive_type == F16) && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape); } diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 99c5e16db70c6a203b4751c1ed8a106c0dc260e6..e97113dfa0f59e791d614c0093d0781e49c48ee4 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -115,7 +115,7 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( for (int i = 0; i < hlo->operand_count(); i++) { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(i))( - ElementwiseSourceIndex(index, *hlo, 0))); + ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 788217aab6172b4e548452b3f6ffd4197c163ce4..f209a69e3cd0f8d336d61bafd1e22be8bc88ca3f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -34,14 +34,16 @@ bool PotentiallyImplementedAsEigenConvolution( // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); - const Shape& kernel_shape = convolution.operand(0)->shape(); + const Shape& kernel_shape = convolution.operand(1)->shape(); if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // Make sure input and kernel has the same data type. + CHECK( + ShapeUtil::SameElementTypeIgnoringFpPrecision(input_shape, kernel_shape)); // TODO(b/65408531): Explore using Eigen dot for complex64 type. - if (ShapeUtil::ElementIsComplex(input_shape) || - ShapeUtil::ElementIsComplex(kernel_shape)) { + if (ShapeUtil::ElementIsComplex(input_shape)) { return false; } if (window_util::HasWindowReversal(convolution.window())) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..215f48c4cc1a1a6b13d98dff76e0d1f0f773f5c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +TEST(IrEmitterTest, ConvWithZeroSizedKernelNotImplementedAsEigen) { + const char* const hlo_string = R"( +HloModule ModuleWithConv + +ENTRY Conv { + input = f32[32,50,28,28]{3,2,1,0} parameter(0) + kernel = f32[0,32,5,5]{3,2,1,0} parameter(1) + ROOT convolution = f32[64,50,24,24]{3,2,1,0} convolution(input, kernel), + window={size=5x5}, + dim_labels=b01f_01io->b01f +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloComputation* entry_computation = module->entry_computation(); + + HloInstruction* conv_instr = entry_computation->root_instruction(); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 4dffaee87f6b33933b58c8c58478eec918569197..3405277d449f2d9e558f2d3f83277163655af592 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -438,12 +438,14 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - ir_builder_.CreateMemCpy(program_buffer_address, acquired_pointer, - length_32, 1); + ir_builder_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, + acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - ir_builder_.CreateMemCpy(acquired_pointer, program_buffer_address, - length_32, 1); + ir_builder_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, + program_buffer_address, + /*SrcAlign=*/1, length_32); } ir_builder_.CreateCall(release_func, @@ -2074,7 +2076,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32})); + /*supported_types=*/{F16, F32})); llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -2441,7 +2443,8 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { auto* memcpy_instruction = ir_builder_.CreateMemCpy( - target, source, element_count * primitive_type_size, element_alignment); + target, /*DstAlign=*/element_alignment, source, + /*SrcAlign=*/element_alignment, element_count * primitive_type_size); // The memcpy does the load and the store internally. The aliasing related // metadata has to reflect that. @@ -2905,7 +2908,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1); + ir_builder_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index cd997f07890cdc1d9a546ede58cc1d992b6416ae..07a9f0efcb64db4b2ff0c6518d4b48eee9a505e0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -394,7 +394,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( for (auto& entry : *function_names_) { tensorflow::mutex_lock lock(jit_mutex_); HloInstruction* instruction = entry.first; - llvm::JITSymbol sym = jit_->FindSymbol(entry.second); + llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry.second); TF_RET_CHECK(sym); InsertOrDie( &functions, instruction, diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index c393e9b8ea39bfb4c605ebba8e2cd29726bc4af9..87c0a3df458eb4b3f217192597e0de1576304367 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -83,12 +83,6 @@ class ParallelCpuExecutable : public Executable { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on CPU parallel executable. - return Unimplemented( - "Equality test on CPU parallel executable is not implemented."); - } - private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 1e439cde11cf74272101b80c867a308e51ab26a6..54af40506dab48b3c2a3a44eb0b5f5fb213a32ec 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -29,7 +29,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( : LoopEmitter(target_element_generator, target_array, ir_builder), dynamic_loop_bounds_(dynamic_loop_bounds) {} -llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( +std::vector +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) { CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); @@ -69,7 +70,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); CHECK(exit_bb_ != nullptr); - return array_index; + return {array_index}; } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index ce92e36a944de33b991d97460f0b2e859ad56081..755715634aa70a822b21d25dcae20a8fe053477a 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -60,7 +60,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; - llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + std::vector EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) override; private: diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index deb21bf4ef5895cfdbec5c2449b6ce7b306a7008..fb28280fade307ac1f193e7dca481bd2afa855fc 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -71,7 +71,7 @@ class DefaultCostModel : public ParallelCostModel { if (flops_to_bytes_ratio <= 1.0) { // Limit max parallelism for I/O bound instructions by assuming a // sub-linear scaling function (fit based on empirical benchmark results). - // TODO(29630486) Develop system bandwidth model. + // TODO(b/29630486) Develop system bandwidth model. max_parallelism = std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); // Use shape size instruction cost and L2 cache size min per-thread cost. @@ -81,7 +81,7 @@ class DefaultCostModel : public ParallelCostModel { // Use max parallelism for compute bound instructions. max_parallelism = max_parallelism_; // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. + // TODO(b/29630486) Improve on this linear cost model. // Consider making 'min_cost_per_thread' be a function of the target // bandwidth limit for instructions with low arithmetic complexity. instruction_cost = @@ -128,24 +128,25 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - instruction->opcode() == HloOpcode::kGetTupleElement || - instruction->opcode() == HloOpcode::kBitcast || - instruction->opcode() == HloOpcode::kFft || - (instruction->opcode() == HloOpcode::kConvolution && + auto opcode = instruction->opcode(); + if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || + opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall || + opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter || + opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || + opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || + opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && + (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { return 1; } + // Consult 'cost_model_' to compute target parallel task count. return cost_model_->GetParallelTaskCount(instruction); } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..13eb75a57213b1a68a5732a4f6061efdf97fa4f4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class ParallelTaskAssignmentTest : public HloVerifiedTestBase { + protected: + const HloCostAnalysis::ShapeSizeFunction shape_size_func_ = + cpu::CpuExecutable::ShapeSizeBytes; + + // Use any value larger than 2 since we only test whether a module is + // parallelized or not + const int max_parallelism_ = 10; +}; + +TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_Dot + ENTRY Dot { + dot_lhs = f32[196614,2]{1,0} parameter(0) + dot_rhs = f32[2,1]{1,0} parameter(1) + ROOT dot = f32[196614,1]{1,0} dot(dot_lhs, dot_rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, + FusedComputationWithDotOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_DotNestedInFusedComp + fused_computation.0 { + parameter.0 = f32[196614,2]{1,0} parameter(0) + parameter.0.1 = f32[2,1]{1,0} parameter(1) + parameter.0.2 = f32[196614,1]{1,0} parameter(2) + dot.0 = f32[196614,1]{1,0} dot(parameter.0, parameter.0.1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT add.0 = f32[196614,1]{1,0} add(dot.0, parameter.0.2) + + } + ENTRY DotNestedInFusedComp { + parameter = f32[196614,2]{1,0} parameter(0) + parameter.1 = f32[2,1]{1,0} parameter(1) + parameter.2 = f32[196614,1]{1,0} parameter(2) + ROOT fusion = f32[196614,1]{1,0} fusion(parameter, parameter.1, + parameter.2), kind=kOutput, calls=fused_computation.0 + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_rng + ENTRY Rng { + src0 = f32[] parameter(0) + src1 = f32[] parameter(1) + ROOT rng0 = f32[1234567,2]{1,0} rng(f32[] src0, f32[] src1), + distribution=rng_uniform + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { + const string hlo_string = R"( + HloModule TestTaskParallel_infeed_outfeed + ENTRY InfeedOutfeed { + infeed0 = u32[12345678,2]{1,0} infeed() + ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( + max_parallelism_, shape_size_func_) + .Run(&module())); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h index 39e20ed45639040110b99ddb52eb6f6dab26dfaa..7337c907f5c83d608641b7382e75902e6f6c05d4 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc b/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc new file mode 100644 index 0000000000000000000000000000000000000000..af0275c8bd00c82220fbe116eb90d2692393713b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/core/platform/macros.h" + +namespace { +using tensorflow::uint16; +using tensorflow::uint32; + +// Helper class that lets us access the underlying bit representation +// of a float without breaking C++ strict aliasing. +class AliasedFloatInt { + public: + static_assert(sizeof(float) == sizeof(uint32), ""); + + static AliasedFloatInt FromFloat(float f) { + AliasedFloatInt value; + value.set_float(f); + return value; + } + + static AliasedFloatInt FromUInt(uint32 u) { + AliasedFloatInt value; + value.set_uint(u); + return value; + } + + void set_float(float f) { memcpy(&value_, &f, sizeof(f)); } + float as_float() const { + float f; + memcpy(&f, &value_, sizeof(f)); + return f; + } + + void set_uint(uint32 u) { value_ = u; } + uint32 as_uint() const { return value_; } + + private: + uint32 value_; +}; +} // namespace + +// __gnu_f2h_ieee and __gnu_h2f_ieee are marked as weak symbols so if XLA is +// built with compiler-rt (that also defines these symbols) we don't get a +// duplicate definition linker error. Making these symbols weak also ensures +// that the compiler-rt definitions "win", but that isn't essential. + +// Algorithm copied from Eigen. +uint16 TF_ATTRIBUTE_WEAK __gnu_f2h_ieee(float float_value) { + AliasedFloatInt f = AliasedFloatInt::FromFloat(float_value); + + const AliasedFloatInt f32infty = AliasedFloatInt::FromUInt(255 << 23); + const AliasedFloatInt f16max = AliasedFloatInt::FromUInt((127 + 16) << 23); + const AliasedFloatInt denorm_magic = + AliasedFloatInt::FromUInt(((127 - 15) + (23 - 10) + 1) << 23); + unsigned int sign_mask = 0x80000000u; + uint32 o = static_cast(0x0u); + + unsigned int sign = f.as_uint() & sign_mask; + f.set_uint(f.as_uint() ^ sign); + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.as_uint() >= + f16max.as_uint()) { // result is Inf or NaN (all exponent bits set) + o = (f.as_uint() > f32infty.as_uint()) ? 0x7e00 + : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.as_uint() < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.set_float(f.as_float() + denorm_magic.as_float()); + + // and one integer subtract of the bias later, we have our final float! + o = static_cast(f.as_uint() - denorm_magic.as_uint()); + } else { + unsigned int mant_odd = + (f.as_uint() >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.set_uint(f.as_uint() + (static_cast(15 - 127) << 23) + + 0xfff); + // rounding bias part 2 + f.set_uint(f.as_uint() + mant_odd); + // take the bits! + o = static_cast(f.as_uint() >> 13); + } + } + + o |= static_cast(sign >> 16); + return o; +} + +// Algorithm copied from Eigen. +float TF_ATTRIBUTE_WEAK __gnu_h2f_ieee(uint16 h) { + const AliasedFloatInt magic = AliasedFloatInt::FromUInt(113 << 23); + const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + AliasedFloatInt o; + + o.set_uint((h & 0x7fff) << 13); // exponent/mantissa bits + unsigned int exp = shifted_exp & o.as_uint(); // just the exponent + o.set_uint(o.as_uint() + ((127 - 15) << 23)); // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.set_uint(o.as_uint() + ((128 - 16) << 23)); // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.set_uint(o.as_uint() + (1 << 23)); // extra exp adjust + o.set_float(o.as_float() - magic.as_float()); // renormalize + } + + o.set_uint(o.as_uint() | (h & 0x8000) << 16); // sign bit + return o.as_float(); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fp16.h b/tensorflow/compiler/xla/service/cpu/runtime_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..01d92d031904af99884c2583a8c7b5086b289d44 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fp16.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ + +#include "tensorflow/core/platform/types.h" + +// Converts an F32 value to a F16. +extern "C" tensorflow::uint16 __gnu_f2h_ieee(float); + +// Converts an F16 value to a F32. +extern "C" float __gnu_h2f_ieee(tensorflow::uint16); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index bff57d33ae23fbba8c664cbd18df77e4c35eb592..39b13183ff093611a42b3931d45f64eadb420622 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -63,30 +63,41 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims); } +template +void MatMulImpl(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, + int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { + if (m == 1 || n == 1) { + // Despite being single threaded, this version of matrix * vector is faster. + xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } else { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); + } +} + } // namespace +void __xla_cpu_runtime_EigenMatMulF16(const void* run_options_ptr, + Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - // Despite being single threaded, this version of matrix * vector is faster. - xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - // Despite being single threaded, this version of matrix * vector is faster. - xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + MatMulImpl(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h index fdb644651dd5d0fa0345580f52ed0fb051672285..d96fe3d58bd5ffbad347e3ede3534d1d47be697a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { @@ -25,6 +26,12 @@ extern "C" { // order. 'out' is a pointer to a buffer sufficiently large to hold the result // of the operation. Following standard nomenclature: lhs is m x k, // rhs is k x n, and out is m x n. +extern void __xla_cpu_runtime_EigenMatMulF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenMatMulF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc new file mode 100644 index 0000000000000000000000000000000000000000..92da5f71c23d5e1450b39ea8b7bb8345f6fabb3b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" +#include "third_party/intel_mkl_ml/include/mkl_cblas.h" +#include "third_party/intel_mkl_ml/include/mkl_service.h" + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/types.h" + +#define EIGEN_USE_THREADS +#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool" + +using tensorflow::int32; +using tensorflow::int64; + +namespace { +// BLAS GEMM API for 32-bit Matrix Multiplication. + +// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c. +// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0. +// Matrix lhs, rhs and out are all colum-major. +void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const float alpha = 1.0f, beta = 0.0f; + // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c, + // respectively. For column-major matrices, the leading dimension is the + // stride between consecutive columns (which equals the number of rows). If + // the matrix is transposed, the leading dimension is the stride between + // consecutive rows (which equals the number of columns). + int lda = transpose_lhs ? k : m; + int ldb = transpose_rhs ? n : k; + int ldc = m; + cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans, + transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs, + lda, rhs, ldb, beta, out, ldc); +} + +// BLAS GEMM API for 64-bit Matrix Multiplication. + +// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c. +// Since XLA MatMul does not used alpha, beta, we set them to 1.0 and 0.0. +// Matrix lhs, rhs and out are all colum-major. +void MatMulF64(const void* run_options_ptr, double* out, double* lhs, + double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const float alpha = 1.0f, beta = 0.0f; + // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c, + // respectively. For a column-major matrix, the leading dimension is the + // stride between consecutive columns (which equals the number of rows). If + // the matrix is transposed, the leading dimension is the stride between + // consecutive rows (which equals the number of columns). + int lda = transpose_lhs ? k : m; + int ldb = transpose_rhs ? n : k; + int ldc = m; + cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans, + transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs, + lda, rhs, ldb, beta, out, ldc); +} + +} // namespace + +void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out, + float* lhs, float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread + // number specified in intra_op_thread_pool to MKL. + int prev_num_threads = mkl_set_num_threads_local( + run_options->intra_op_thread_pool()->numThreads()); + MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +// BLAS GEMM API for 64-bit Matrix Multiplication +void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out, + double* lhs, double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread + // number specified in intra_op_thread_pool to MKL. + int prev_num_threads = mkl_set_num_threads_local( + run_options->intra_op_thread_pool()->numThreads()); + MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr, + float* out, float* lhs, + float* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + // Set the thread number to 1 for single threaded excution. + int prev_num_threads = mkl_set_num_threads_local(1); + MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr, + double* out, double* lhs, + double* rhs, int64 m, int64 n, + int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + // Set the thread number to 1 for single threaded excution. + int prev_num_threads = mkl_set_num_threads_local(1); + MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + // Set thread number back to the previous number. + mkl_set_num_threads_local(prev_num_threads); +} +#endif // INTEL_MKL diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h new file mode 100644 index 0000000000000000000000000000000000000000..831b796efb971f6fb0170e2321c00ac415f2830f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ + +#include +#include "tensorflow/core/platform/types.h" +#ifdef INTEL_MKL +#include "third_party/intel_mkl_ml/include/mkl_cblas.h" + +extern void __xla_cpu_runtime_MKLMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_MKLMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + +#else +extern void __xla_cpu_runtime_MKLMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} +extern void __xla_cpu_runtime_MKLMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} +extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out, + double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + std::cerr << "Attempt to call MKL MatMul runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +} + +#endif // INTEL_MKL +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc b/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc deleted file mode 100644 index 435820cdd36e2a906d9dfbe2555f4c0df623c729..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "third_party/eigen3/Eigen/Core" -#include "tensorflow/compiler/xla/service/cpu/runtime_matvec.h" - -using tensorflow::int32; -using tensorflow::int64; - -namespace { - -// Does mat * x or mat^T * x. -template -void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols, - int32 transpose) { - // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to - // be faster (b/30223679). See also: the matmul op kernel in TensorFlow, - // which implements the same optimization. - using Matrix = Eigen::Matrix; - using MatrixMap = Eigen::Map; - - using Vector = Eigen::Matrix; - using VectorMap = Eigen::Map; - - auto x = VectorMap(x_buf, cols); - auto out = VectorMap(out_buf, rows); - - int64 mat_rows = rows; - int64 mat_cols = cols; - - if (transpose) { - std::swap(mat_rows, mat_cols); - } - - auto mat = MatrixMap(mat_buf, mat_rows, mat_cols); - - if (transpose) { - out = mat.transpose() * x; - } else { - out = mat * x; - } -} - -// Converts matmul-style args to matvec. -template -void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, - int32 transpose_lhs, int32 transpose_rhs) { - // If the input is in the form x * A, where x is the vector, then bring A back - // over to the left hand side. We make use of the identity - // - // (x * A)^T = A^T * x^T - // - // We do not need to take the transpose of x or of the result since taking - // the transpose of a vector does not change the memory layout. - const int64 cols = k; - - T* mat; - T* vec; - int64 rows; - bool transpose_mat; - - bool is_mat_vec = (n == 1); - - if (is_mat_vec) { - mat = lhs; - vec = rhs; - rows = m; - transpose_mat = transpose_lhs; - } else { - mat = rhs; - vec = lhs; - rows = n; - transpose_mat = !transpose_rhs; - } - - MatVec(out, mat, vec, rows, cols, transpose_mat); -} - -} // namespace - -namespace xla { - -void EigenMatVecF32(float* out, float* lhs, float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { - assert((m == 1 || n == 1) && "not a matrix-vector multiply"); - DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -void EigenMatVecF64(double* out, double* lhs, double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { - assert((m == 1 || n == 1) && "not a matrix-vector multiply"); - DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h index 1bd8dfb377acc1f7cfbe9a92773f87f0ef25de3a..70eb98c54169824e220d9287753c0849362eade6 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h @@ -16,10 +16,86 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#include "third_party/eigen3/Eigen/Core" + #include "tensorflow/core/platform/types.h" namespace xla { +namespace detail { + +using tensorflow::int32; +using tensorflow::int64; + +// Does mat * x or mat^T * x. +template +void MatVec(T* out_buf, T* mat_buf, T* x_buf, int64 rows, int64 cols, + int32 transpose) { + // Use an Eigen Matrix instead of a Tensor, as the GEMV from Matrix seems to + // be faster (b/30223679). See also: the matmul op kernel in TensorFlow, + // which implements the same optimization. + using Matrix = Eigen::Matrix; + using MatrixMap = Eigen::Map; + + using Vector = Eigen::Matrix; + using VectorMap = Eigen::Map; + + auto x = VectorMap(x_buf, cols); + auto out = VectorMap(out_buf, rows); + + int64 mat_rows = rows; + int64 mat_cols = cols; + + if (transpose) { + std::swap(mat_rows, mat_cols); + } + + auto mat = MatrixMap(mat_buf, mat_rows, mat_cols); + + if (transpose) { + out = mat.transpose() * x; + } else { + out = mat * x; + } +} + +// Converts matmul-style args to matvec. +template +void DispatchMatVec(T* out, T* lhs, T* rhs, int64 m, int64 n, int64 k, + int32 transpose_lhs, int32 transpose_rhs) { + // If the input is in the form x * A, where x is the vector, then bring A back + // over to the left hand side. We make use of the identity + // + // (x * A)^T = A^T * x^T + // + // We do not need to take the transpose of x or of the result since taking + // the transpose of a vector does not change the memory layout. + const int64 cols = k; + + T* mat; + T* vec; + int64 rows; + bool transpose_mat; + + bool is_mat_vec = (n == 1); + + if (is_mat_vec) { + mat = lhs; + vec = rhs; + rows = m; + transpose_mat = transpose_lhs; + } else { + mat = rhs; + vec = lhs; + rows = n; + transpose_mat = !transpose_rhs; + } + + MatVec(out, mat, vec, rows, cols, transpose_mat); +} + +} // namespace detail + // Performs a matrix-vector multiplication using Eigen. 'lhs' and 'rhs' are // pointers to buffers containing input matrices in column-major order. 'out' is // a pointer to a buffer sufficiently large to hold the result of the @@ -30,15 +106,15 @@ namespace xla { // // TODO(b/64684907): Compare runtime performance of these functions with dot // simplification. -void EigenMatVecF32(float* out, float* lhs, float* rhs, tensorflow::int64 m, - tensorflow::int64 n, tensorflow::int64 k, - tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); - -void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m, - tensorflow::int64 n, tensorflow::int64 k, - tensorflow::int32 transpose_lhs, - tensorflow::int32 transpose_rhs); +template +void EigenMatVec(T* out, T* lhs, T* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + assert((m == 1 || n == 1) && "not a matrix-vector multiply"); + detail::DispatchMatVec(out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h index f216bd0152aa93b8753d881938c63a9cabea899b..44b201725b2c724f48c1a3f0373c41e76211e0c2 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index ee8eb081556d60fcf6537b1036a4a5825c4c7bf6..17303e2f0d34e531a3a56aa147608b949e0f43ae 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -57,26 +57,38 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, C = A.contract(B, dims); } +template +void SingleThreadedMatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + if (m == 1 || n == 1) { + xla::EigenMatVec(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); + } else { + MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, + transpose_rhs); + } +} + } // namespace +void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, + int32 transpose_rhs) { + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); +} + void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - xla::EigenMatVecF32(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } void __xla_cpu_runtime_EigenSingleThreadedMatMulF64( const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - if (m == 1 || n == 1) { - xla::EigenMatVecF64(out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); - } else { - MatMul(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); - } + SingleThreadedMatMul(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h index 029eb9514287d8c69cde2cfb06e0d56e78d6f165..82a1fcce594fa5b04f4fe459870991863c32a91a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" extern "C" { @@ -25,6 +26,12 @@ extern "C" { // 'out' is a pointer to a buffer sufficiently large to hold the result of the // operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and // out is m x n. +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs); + extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n, diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 61b408b8c24dded134218110d4e219c31f1685a8..42fe955f1917e0268dc739e44fbd0a7afb39185c 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -20,12 +20,13 @@ namespace cpu { std::vector ShapePartitionAssigner::Run(int64 target_partition_count) { // Gather outer-most dims where dim_size >= 'target_partition_count'. - // Note: always leave inner-dim static for vectorization/optimizations. + // This may include the inner-dim as LLVM can vectorize loops with dynamic + // bounds. std::vector outer_dims; int64 outer_dim_size = 1; // TODO(b/27458679) Consider reserving enough minor dimensions (based on // target vector register width) to enable vector instructions. - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { const int64 dimension = shape_.layout().minor_to_major(i); outer_dims.push_back(dimension); outer_dim_size *= shape_.dimensions(dimension); diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ee0c53fa6d7c41481a53350e57e5844dea2644c1..ae80a6f4977f85cfd9f872734fd0a69432a1f382 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -30,105 +30,65 @@ class ShapePartitionAssignerTest : public HloTestBase { protected: typedef std::vector Vec; - void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + void RunR2Test(const Shape& shape, int64 max_target_partition_count, + const std::vector* expected_partitions) { ShapePartitionAssigner assigner(shape); - // Check all partitions of outer dimension. - for (int64 i = 1; i <= expected_max_partition_count; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), - assigner.Run(/*target_partition_count=*/i))); + // Iterate through 1..max_target_partition_count. + for (int64 i = 1; i <= max_target_partition_count; ++i) { + std::vector actual_partitions = + assigner.Run(/*target_partition_count=*/i); + EXPECT_THAT(actual_partitions, expected_partitions[i - 1]); } - // Check target_partition_count > outer dimension size. - EXPECT_TRUE(ContainersEqual( - Vec({expected_max_partition_count}), - assigner.Run( - /*target_partition_count=*/expected_max_partition_count + 1))); } }; TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); + std::vector expected_partitions[] = {{1} /* 1 */, {1, 2} /* 2 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); + std::vector expected_partitions[] = { + {1} /* 1 */, {1, 2} /* 2 */ + }; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 2, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); + std::vector expected_partitions[] = {{1} /* 1 */, {2} /* 2 */, + {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 6, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { - RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 4, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 5; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({4, 3}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({5, 3}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {4} /* 4 */, + {5} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {4, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {4, 3} /* 12 */, + {4, 3} /* 13 */, {4, 3} /* 14 */, {5, 3} /* 15 */, {4, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}), 16, + expected_partitions); } TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { - Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); - ShapePartitionAssigner assigner(shape); - - for (int64 i = 1; i <= 3; ++i) { - EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( - /*target_partition_count=*/i))); - } - - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); - EXPECT_TRUE( - ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); - EXPECT_TRUE( - ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/10))); - EXPECT_TRUE(ContainersEqual(Vec({3, 3}), - assigner.Run(/*target_partition_count=*/11))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/12))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/13))); - EXPECT_TRUE(ContainersEqual(Vec({3, 4}), - assigner.Run(/*target_partition_count=*/14))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/15))); - EXPECT_TRUE(ContainersEqual(Vec({3, 5}), - assigner.Run(/*target_partition_count=*/16))); + std::vector expected_partitions[] = { + {1} /* 1 */, {2} /* 2 */, {3} /* 3 */, {2, 2} /* 4 */, + {2, 2} /* 5 */, {3, 2} /* 6 */, {3, 2} /* 7 */, {3, 2} /* 8 */, + {3, 3} /* 9 */, {3, 3} /* 10 */, {3, 3} /* 11 */, {3, 4} /* 12 */, + {3, 4} /* 13 */, {3, 4} /* 14 */, {3, 5} /* 15 */, {3, 2, 2} /* 16 */}; + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}), 16, + expected_partitions); } class ShapePartitionIteratorTest : public HloTestBase { diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index f19cb86cc42f03543c77ff17b6ae4a8a69bbb140..b7ce5bbe47482320bfb9524c8f366a463b9579ed 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -33,7 +33,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" @@ -44,36 +46,6 @@ namespace xla { namespace cpu { namespace { -// A simple SymbolResolver that delegates to the host dynamic linker. -class SimpleResolver : public llvm::LegacyJITSymbolResolver { - public: - explicit SimpleResolver(ExternalConstantPool* external_constant_pool) - : external_constant_pool_(external_constant_pool) {} - - llvm::JITSymbol findSymbol(const std::string& name) override { - if (const uint8* from_constant_pool = - external_constant_pool_->Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); - if (func_addr == nullptr) { - return nullptr; - } - llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), - llvm::JITSymbolFlags::None); - return symbol_info; - } - llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { - return nullptr; - } - - private: - ExternalConstantPool* external_constant_pool_; -}; - llvm::SmallVector DetectMachineAttributes() { llvm::SmallVector result; llvm::StringMap host_features; @@ -116,35 +88,22 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - execution_session_(string_pool_), symbol_resolver_(llvm::orc::createLegacyLookupResolver( [this](const std::string& name) -> llvm::JITSymbol { - if (const uint8* from_constant_pool = - external_constant_pool_.Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); - if (func_addr == nullptr) { - return nullptr; - } - llvm::JITEvaluatedSymbol symbol_info( - reinterpret_cast(func_addr), - llvm::JITSymbolFlags::None); - return symbol_info; + return this->ResolveRuntimeSymbol(name); }, [](llvm::Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })), - object_layer_( - execution_session_, - [](llvm::orc::VModuleKey) { - return std::make_shared( - orc_jit_memory_mapper::GetInstance()); - }, - [this](llvm::orc::VModuleKey K) { return symbol_resolver_; }), + object_layer_(execution_session_, + [this](llvm::orc::VModuleKey) { + llvm::orc::RTDyldObjectLinkingLayer::Resources result; + result.MemMgr = + std::make_shared( + orc_jit_memory_mapper::GetInstance()); + result.Resolver = symbol_resolver_; + return result; + }), compile_layer_(object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, optimize_for_size, @@ -155,6 +114,23 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, << " features: " << target_machine_->getTargetFeatureString().str(); } +llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { + if (const uint8* from_constant_pool = + external_constant_pool_.Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + if (func_addr == nullptr) { + return nullptr; + } + llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), + llvm::JITSymbolFlags::None); + return symbol_info; +} + SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto key = execution_session_.allocateVModule(); @@ -169,19 +145,13 @@ void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { cantFail(compile_layer_.removeModule(key)); } -llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { - std::string mangled_name; - { - llvm::raw_string_ostream mangled_name_stream(mangled_name); - llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, data_layout_); - } - +llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { // Resolve symbol from last module to first, allowing later redefinitions of // symbols shadow earlier ones. for (auto& key : llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { if (auto symbol = - compile_layer_.findSymbolIn(key, mangled_name, + compile_layer_.findSymbolIn(key, name, /*ExportedSymbolsOnly=*/true)) { return symbol; } @@ -211,16 +181,25 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); + registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); + #undef REGISTER_CPU_RUNTIME_SYMBOL // Register both the f32 (float) and f64 (double) versions of a libm symbol. diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 50993afc8f73617a2c65310ae73b3ab00519f550..f4260a95bc45557b6cd969f7d3fff01c8b392575 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -46,9 +46,7 @@ namespace cpu { class SimpleOrcJIT { public: using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; - using CompileFtor = - std::function( - llvm::Module&)>; + using CompileFtor = std::function; using CompileLayerT = llvm::orc::IRCompileLayer; using VModuleKeyT = llvm::orc::VModuleKey; @@ -89,7 +87,7 @@ class SimpleOrcJIT { // Get the runtime address of the compiled symbol whose name is given. Returns // nullptr if the symbol cannot be found. - llvm::JITSymbol FindSymbol(const std::string& name); + llvm::JITSymbol FindCompiledSymbol(const std::string& name); llvm::TargetMachine* target_machine() const { return target_machine_.get(); } @@ -98,11 +96,12 @@ class SimpleOrcJIT { } private: + llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; const llvm::DataLayout data_layout_; - llvm::orc::SymbolStringPool string_pool_; llvm::orc::ExecutionSession execution_session_; std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 150db1cb6edec1af6724a8bca6a5f6272f1a7416..cd1165e23812861ba9951546b7dd744529232196 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -370,6 +370,9 @@ std::vector VectorSupportLibrary::ComputeHorizontalSums( std::vector VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( std::vector vectors, llvm::Value* init_values) { + // vectors are N llvm vector values, each with N elements. + int64 lane_width = vectors.size(); + while (vectors.size() != 2) { std::vector new_vectors; for (int i = 0; i < vectors.size(); i += 2) { @@ -390,10 +393,14 @@ VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( high = AddInternal(ExtractHighHalf(init_values), high); } + // `low` has the first `lane_width / 2` horizontal reductions, and `high` has + // the next `lane_width / 2` horizontal reductions. + std::vector results; - for (int i = 0; i < 8; i++) { + for (int i = 0; i < lane_width; i++) { llvm::Value* scalar_result = ir_builder()->CreateExtractElement( - i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); + i < (lane_width / 2) ? low : high, + ir_builder()->getInt32(i % (lane_width / 2)), name()); results.push_back(scalar_result); } diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc new file mode 100644 index 0000000000000000000000000000000000000000..d938f3a2c4b5bfdd70d5a614b9890b4d7bf050f7 --- /dev/null +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/despecializer.h" + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/defuser.h" +#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" + +namespace xla { + +Despecializer::Despecializer() : pipeline_("despecializer") { + // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); + pipeline_.AddPass(); + pipeline_.AddPass(); +} + +StatusOr Despecializer::Run(HloModule* module) { + return pipeline_.Run(module); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h new file mode 100644 index 0000000000000000000000000000000000000000..af48f4ab6e506d295251239fe92db68cfec6dcfa --- /dev/null +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Creates an HloPassPipeline containing multiple HloPasses that can +// despecialize an optimized HloModule. This is useful to run an HloModule +// optimized for one specfic platform on a different platform (undoing platform +// specific passes) with matching numerics for comparison. +// +// Current despecialization passes are Defuser, ImplicitBroadcastRemover, +// and BFloat16MixedPrecisionRemoval. +class Despecializer : public HloPassInterface { + public: + Despecializer(); + tensorflow::StringPiece name() const override { return "despecializer"; } + StatusOr Run(HloModule* module) override; + + private: + HloPassPipeline pipeline_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index a803b3171f9afa6297553c5507c4f9aa45e420ab..3f7089d6ca1e1a3b9bb42028327ba54ba4b93974 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -190,6 +190,7 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; @@ -198,6 +199,7 @@ class DfsHloVisitorBase { virtual Status HandleReduce(HloInstructionPtr hlo) = 0; virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; + virtual Status HandleBroadcastDimOne(HloInstructionPtr hlo) = 0; virtual Status HandleReshape(HloInstructionPtr hlo) = 0; virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; virtual Status HandleParameter(HloInstructionPtr hlo) = 0; @@ -213,6 +215,7 @@ class DfsHloVisitorBase { virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; + virtual Status HandleGather(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 170adb3d241b3648bc53f96dde9866f0b794f80a..e6680ee9b87e1a01782204047c3b2104995c11ed 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -35,6 +35,12 @@ class HloInstruction; // DfsHloVisitor with default action based on the HloInstruction being visited. // Users should not use this class directly, but use the type aliases // DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. +// +// Do *not* add an override to this class if the opcode is covered by +// HandleElementwiseUnary/Binary. These opcode handlers dispatch to +// HandleElementwiseUnary/Binary in DfsHloVisitorBase. Adding such a handler +// here will break passes which rely on the HandleElementwiseUnary/Binary +// handling these opcodes. template class DfsHloVisitorWithDefaultBase : public DfsHloVisitorBase { @@ -70,12 +76,6 @@ class DfsHloVisitorWithDefaultBase Status HandleConcatenate(HloInstructionPtr concatenate) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstructionPtr convert) override { - return DefaultAction(convert); - } - Status HandleCopy(HloInstructionPtr copy) override { - return DefaultAction(copy); - } Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } @@ -91,9 +91,6 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleCompare(HloInstructionPtr compare) override { - return DefaultAction(compare); - } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } @@ -103,6 +100,9 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } + Status HandleHostCompute(HloInstructionPtr host_compute) override { + return DefaultAction(host_compute); + } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } @@ -158,6 +158,9 @@ class DfsHloVisitorWithDefaultBase Status HandleBroadcast(HloInstructionPtr broadcast) override { return DefaultAction(broadcast); } + Status HandleBroadcastDimOne(HloInstructionPtr broadcastDimOne) override { + return DefaultAction(broadcastDimOne); + } Status HandlePad(HloInstructionPtr pad) override { return DefaultAction(pad); } @@ -185,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSendDone(HloInstructionPtr send_done) override { return DefaultAction(send_done); } + Status HandleGather(HloInstructionPtr gather) override { + return DefaultAction(gather); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..825e1436f0ec6d49b555e5e3e9c2c7a19fb7b062 --- /dev/null +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class DfsHloVisitorWithDefaultTest : public HloTestBase {}; + +TEST_F(DfsHloVisitorWithDefaultTest, DefaultElementwiseTest) { + // Verify that HandleElementwiseBinary and HandleElementwiseUnary are called + // on the appropriate HLO ops (elementwise binary/unary ops). + + class ElementwiseTestVisitor : public DfsHloVisitorWithDefault { + public: + Status DefaultAction(HloInstruction* hlo) override { + // The HLO should be neither an elementwise unary nor binary op. These + // cases are handled in HandleElementwiseBinary/Unary. + TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 2)) + << hlo->ToString(); + TF_RET_CHECK(!(hlo->IsElementwise() && hlo->operand_count() == 1)) + << hlo->ToString(); + return Status::OK(); + } + + Status HandleElementwiseBinary(HloInstruction* hlo) override { + // HLO should be elementwise binary. + TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 2) + << hlo->ToString(); + return Status::OK(); + } + Status HandleElementwiseUnary(HloInstruction* hlo) override { + // HLO should be elementwise unary. + TF_RET_CHECK(hlo->IsElementwise() && hlo->operand_count() == 1) + << hlo->ToString(); + return Status::OK(); + } + }; + + // HLO module contains are arbitrary mix of elementwise and non-elementwise + // operations. + const string& hlo_string = R"( +HloModule TestModule + +ENTRY TestComputation { + arg = f32[] parameter(0) + tuple = (f32[]) tuple(arg) + gte = f32[] get-tuple-element(tuple), index=0 + abs = f32[] abs(arg) + add = f32[] add(arg, gte) + broadcast = f32[42] broadcast(add), dimensions={} + slice = f32[0] slice(broadcast), slice={[1:2]} + copy = f32[] copy(arg) + eq = pred[] equal-to(arg, gte) + neg = f32[] negate(arg) + ROOT convert = f64[] convert(f32[] arg) +})"; + std::unique_ptr module = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()) + .ConsumeValueOrDie(); + ElementwiseTestVisitor visitor; + TF_EXPECT_OK(module->entry_computation()->Accept(&visitor)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4468adbadbf823f1420a8b665a26f66cb7d36b43..b6a0903b0eeaa04d8bc1488378c148b2016c5d48 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -226,7 +226,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (primitive_util::IsIntegralType(to_type)) { return ir_builder_->CreateIntCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(to_type)); + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -1003,6 +1003,30 @@ StatusOr ElementalIrEmitter::EmitReducePrecision( ir_builder_); } +static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, + llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* shift_result, + bool saturate_to_sign_bit) { + llvm::IntegerType* integer_type = + llvm::cast(lhs->getType()); + unsigned integer_bitsize = integer_type->getBitWidth(); + llvm::ConstantInt* integer_bitsize_constant = + llvm::ConstantInt::get(integer_type, integer_bitsize); + llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0); + llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1); + llvm::Value* saturated_value; + if (saturate_to_sign_bit) { + saturated_value = ir_builder->CreateSelect( + ir_builder->CreateICmpSLT(lhs, zero), minus_one, zero); + } else { + saturated_value = zero; + } + llvm::Value* shift_amt_in_range = + ir_builder->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); + return ir_builder->CreateSelect(shift_amt_in_range, shift_result, + saturated_value); +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -1050,12 +1074,27 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); - case HloOpcode::kShiftLeft: - return ir_builder_->CreateShl(lhs_value, rhs_value); + + // Shifting out bits >= the number of bits in the type being shifted + // produces a poison value in LLVM which is basically "deferred undefined + // behavior" -- doing something observable with such a value precipitates + // UB. We replace the poison value with a constant to avoid this deferred + // UB. case HloOpcode::kShiftRightArithmetic: - return ir_builder_->CreateAShr(lhs_value, rhs_value); + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateAShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/true); + case HloOpcode::kShiftLeft: + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateShl(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: - return ir_builder_->CreateLShr(lhs_value, rhs_value); + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateLShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -1483,15 +1522,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kBroadcast: return [this, hlo, &operand_to_generator]( const IrArray::Index& target_index) -> StatusOr { + const HloInstruction* operand = hlo->operand(0); // The `dimensions` member of the broadcast instruction maps from // input dimensions to output dimensions. - const HloInstruction* operand = hlo->operand(0); - int64 rank = ShapeUtil::Rank(operand->shape()); - IrArray::Index source_index(rank); - for (int64 i = 0; i < rank; ++i) { - source_index[i] = target_index[hlo->dimensions(i)]; - } - return operand_to_generator.at(operand)(source_index); + return operand_to_generator.at( + operand)(target_index.SourceIndexOfBroadcast( + hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_)); }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( @@ -1683,6 +1719,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(if_data.after_block, ir_builder_); return ir_builder_->CreateLoad(ret_value_addr); }; + case HloOpcode::kBitcast: + CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), + ShapeUtil::ElementsIn(hlo->operand(0)->shape())); + return [this, hlo, &operand_to_generator](const IrArray::Index& index) { + const HloInstruction* operand = hlo->operand(0); + return operand_to_generator.at(operand)(index.SourceIndexOfBitcast( + hlo->shape(), operand->shape(), ir_builder_)); + }; case HloOpcode::kReshape: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), ShapeUtil::ElementsIn(hlo->operand(0)->shape())); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 90481c7a88f90edea5399ee44aee2d2c77fc115f..471d2fd6cebcd7a00dfea4aca08da08af534b05f 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" @@ -79,6 +80,7 @@ StatusOr> Executable::ExecuteOnStreamWrapper( StatusOr> return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); + TF_RETURN_IF_ERROR(return_value.status()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 0aee535ee780ef000bc5e9963ff48786b3a61eb2..a157235f8af6ea64a488510e427bbae502c46ca6 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -109,14 +108,6 @@ class Executable { return execution_profile_; } - // Returns Status::ok() if the two executables are equal to each other. - // - // An error status is returned otherwise. - virtual const Status EqualOrFail(const Executable& executable) { - return Unimplemented( - "Equality test on this executable is not implemented."); - } - const HloProfilePrinterData& hlo_profile_printer_data() const { CHECK(hlo_profiling_enabled()); return *hlo_profile_printer_data_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 2b6caa149439a86d6d047605099bc3ff7b295a8e..85409b330b11537158059dcce8c2a96c98d38f30 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -93,7 +93,7 @@ Status FlattenNode(const CallGraphNode& node) { auto current = worklist.back(); worklist.pop_back(); for (auto* instruction : current->instructions()) { - if (GetInstructionCallContext(instruction) != + if (GetInstructionCallContext(instruction->opcode()) != CallContext::kSequential) { continue; } diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc new file mode 100644 index 0000000000000000000000000000000000000000..221ff7900f398166c193c495848a2afcfd4edc81 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -0,0 +1,392 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +using tensorflow::gtl::ArraySlice; + +static StatusOr TransposeIndexVectorDimToLast( + HloInstruction* gather_indices, int64 index_vector_dim) { + const Shape& gather_indices_shape = gather_indices->shape(); + if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { + return gather_indices; + } + std::vector permutation; + permutation.reserve(gather_indices_shape.dimensions_size()); + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(gather_indices, permutation); +} + +// If the gather_indices holds scalar indices (i.e. gather_indices has rank N +// and index_vector_dim is N) then reshape it to have a trailing degenerate +// dimension. This makes the code for slicing out the index vector more +// uniform. +static StatusOr DeScalarizeGatherIndices( + HloInstruction* gather_indices, int64 index_vector_dim) { + const Shape& gather_indices_shape = gather_indices->shape(); + if (index_vector_dim != gather_indices_shape.dimensions_size()) { + return gather_indices; + } + + DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size()); + + std::vector result_shape_dims; + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(result_shape_dims)); + result_shape_dims.push_back(1); + + return MakeReshapeHlo(result_shape_dims, gather_indices); +} + +// Canonicalizes the gather_indices tensors so that we only have deal with some +// specific cases in the while loop that does the heavy lifting. +// +// See the "High Level Algorithm" section for a broader picture. +static StatusOr CanonicalizeGatherIndices( + HloInstruction* gather_indices, int64 index_vector_dim) { + // If gather_indices holds scalar indices, normalize it to hold index vectors + // of size 1. + TF_ASSIGN_OR_RETURN( + HloInstruction * descalarized_gather_indices, + DeScalarizeGatherIndices(gather_indices, index_vector_dim)); + + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN(HloInstruction * transposed_gather_indices, + TransposeIndexVectorDimToLast(descalarized_gather_indices, + index_vector_dim)); + + // If there is only one index (i.e. gather_indices has rank 1 and this gather + // is really just a dynamic slice) add a leading degenerate dimension for + // uniformity. Otherwise create a "collapsed" leading dimension that subsumes + // all of the non-index-vector dimensions. + const Shape& shape = transposed_gather_indices->shape(); + if (shape.dimensions_size() == 1) { + return ExpandFirstDimIntoNDims(transposed_gather_indices, + {1, shape.dimensions(0)}); + } else { + return CollapseFirstNDims(transposed_gather_indices, + shape.dimensions_size() - 1); + } +} + +// Expands out or contracts away the gather dimensions in the accumulator +// produced by the while loop. +static StatusOr AdjustGatherDimsInAccumulator( + const Shape& gather_indices_shape, HloInstruction* accumulator, + int64 index_vector_dim) { + std::vector output_gather_dim_bounds; + output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + } + } + + if (output_gather_dim_bounds.empty()) { + // If output_gather_dim_bounds is empty we must be lowering a (effectively) + // dynamic-slice. In that case, there is a leading degenerate gather + // dimension that we added to make this special case play well with the + // general while loop which we need to remove now. + CHECK_EQ(accumulator->shape().dimensions(0), 1); + ArraySlice reshaped_dim_sizes = + AsInt64Slice(accumulator->shape().dimensions()); + reshaped_dim_sizes.remove_prefix(1); + return MakeReshapeHlo(reshaped_dim_sizes, accumulator); + } + + return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); +} + +// Expand an index vector from the gather_indices tensor into a vector that can +// be used to dynamic-slice out of the gather operand. +static StatusOr ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.gather_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// This generates the body of the while that implements the main data movement +// behavior of gather using dynamic-slice and dynamic-update-slice. +static StatusOr> GatherLoopBody( + const HloInstruction& gather, HloInstruction* induction_var, + const std::vector& incoming_loop_state) { + CHECK_EQ(incoming_loop_state.size(), 3); + HloInstruction* const operand = incoming_loop_state[0]; + HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const output_accumulator = incoming_loop_state[2]; + + int64 index_vector_size = gather_indices->shape().dimensions(1); + + TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_gather_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + {1, index_vector_size})); + + TF_ASSIGN_OR_RETURN(HloInstruction * index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + + TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_start, + ExpandIndexVectorIntoOperandSpace( + index_vector, gather.gather_dimension_numbers(), + operand->shape().dimensions_size())); + + TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, + MakeDynamicSliceHlo(operand, gathered_slice_start, + gather.gather_window_bounds())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * gathered_slice_for_update, + ExpandFirstDimIntoNDims(gathered_slice, + {1, gathered_slice->shape().dimensions(0)})); + + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_into_accumulator, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/gathered_slice->shape().dimensions_size())); + + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_accumulator, + MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, + index_vector_into_accumulator)); + + // New loop state -- only the accumulator has changed. The + // WhileUtil::MakeCountedLoop functions takes care of the induction variable + // and the while loop exit condition. + return StatusOr>{ + {operand, gather_indices, updated_accumulator}}; +} + +static StatusOr CreateGatherLoopAccumulatorInitValue( + HloComputation* computation, PrimitiveType element_type, + ArraySlice window_bounds, int64 gather_loop_trip_count) { + std::vector accumulator_state_shape_dims; + accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.push_back(gather_loop_trip_count); + c_copy(window_bounds, std::back_inserter(accumulator_state_shape_dims)); + return BroadcastZeros(computation, element_type, + accumulator_state_shape_dims); +} + +static StatusOr ElideWindowDimsFromAccumulator( + HloInstruction* accumulator, const GatherDimensionNumbers& dim_numbers) { + std::vector dims_to_elide; + dims_to_elide.reserve(dim_numbers.elided_window_dims_size()); + for (int64 elided_window_dim : dim_numbers.elided_window_dims()) { + dims_to_elide.push_back(elided_window_dim + 1); + } + + return ElideDegenerateDims(accumulator, dims_to_elide); +} + +// `accumulator` is almost the tensor the gather operation would have produced, +// except that it has the dimensions in the wrong order -- the gather dimensions +// are the major dimensions and the window dimensions are the minor dimensions. +// Fix this up with a transpose. +static StatusOr PermuteGatherAndWindowDims( + HloInstruction* accumulator, ArraySlice output_window_dims, + int64 output_rank) { + std::vector permutation; + permutation.reserve(output_rank); + + int64 gather_idx_counter = 0; + int64 window_idx_counter = output_rank - output_window_dims.size(); + for (int64 i = 0; i < output_rank; i++) { + bool is_window_dim = c_binary_search(output_window_dims, i); + if (is_window_dim) { + permutation.push_back(window_idx_counter++); + } else { + permutation.push_back(gather_idx_counter++); + } + } + + return MakeTransposeHlo(accumulator, permutation); +} + +// High Level Algorithm +// +// We follow the following steps in sequence: +// +// 1. We canonicalize the gather_indices tensor such that it has rank +// 2 (i.e. is a matrix) where each row is an index vector into the +// operand. +// 2. We iterate over the set of indices in the canonicalized +// gather_indices tensor using a while loop, accumulating slices +// of the operand tensor into an accumulator using +// DynamicUpdateSlice. +// 3. The accumulator result from the while loop from (2) is then +// reshaped to split out all the individual gather dimensions and +// then transposed to give the final result. +// +// As an example, if we started with the following operation: +// +// HloModule TensorFlowGatherMultipleBatchDims +// +// ENTRY main { +// operand = s32[3,3] parameter(0) +// indices = s32[2,2] parameter(1) +// ROOT gather = s32[2,3,2] gather(operand, indices), +// output_window_dims={1}, +// elided_window_dims={1}, +// gather_dims_to_operand_dims={1}, +// index_vector_dim=2, +// window_bounds={3, 1} +// } +// +// We'd first reshape indices to s32[4,1], where each row is an index +// into operand. We'd then run a loop to slice out 4 tensors of shape +// [3,1] out of operand into an accumulator of shape [4,3,1]. We then +// reshape this result to [2,2,3] and finally transpose it to [2,3,2]. + +StatusOr GatherExpander::ExpandGather( + HloInstruction* gather_instr) { + CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + + HloComputation* computation = gather_instr->parent(); + HloInstruction* operand = gather_instr->mutable_operand(0); + HloInstruction* gather_indices = gather_instr->mutable_operand(1); + const Shape& gather_indices_shape = gather_indices->shape(); + const Shape& output_shape = gather_instr->shape(); + int64 output_rank = output_shape.dimensions_size(); + + const GatherDimensionNumbers& dim_numbers = + gather_instr->gather_dimension_numbers(); + + int64 gather_loop_trip_count = 1; + for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + gather_loop_trip_count *= gather_indices_shape.dimensions(i); + } + } + + if (!IsInt32(gather_loop_trip_count)) { + return Unimplemented( + "Gather operations with more than 2147483647 gather indices are not " + "supported. This error occurred for %s.", + gather_instr->ToString().c_str()); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, + CanonicalizeGatherIndices( + gather_indices, dim_numbers.index_vector_dim())); + + CHECK_EQ(gather_loop_trip_count, + canonical_gather_indices->shape().dimensions(0)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_init, + CreateGatherLoopAccumulatorInitValue( + computation, output_shape.element_type(), + gather_instr->gather_window_bounds(), gather_loop_trip_count)); + + StatusOr> gather_loop_result_or_error = + WhileUtil::MakeCountedLoop( + computation, gather_loop_trip_count, + {operand, canonical_gather_indices, accumulator_init}, + [&](HloInstruction* indvar, + const std::vector& loop_state) { + return GatherLoopBody(*gather_instr, indvar, loop_state); + }); + + TF_ASSIGN_OR_RETURN(std::vector gather_loop_result, + gather_loop_result_or_error); + + HloInstruction* accumulator_result = gather_loop_result.back(); + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_with_window_dims_elided, + ElideWindowDimsFromAccumulator(accumulator_result, dim_numbers)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * accumulator_with_output_gather_dims_decanonicalized, + AdjustGatherDimsInAccumulator(gather_indices->shape(), + accumulator_with_window_dims_elided, + dim_numbers.index_vector_dim())); + + return PermuteGatherAndWindowDims( + accumulator_with_output_gather_dims_decanonicalized, + AsInt64Slice(dim_numbers.output_window_dims()), output_rank); +} + +StatusOr GatherExpander::Run(HloModule* module) { + auto is_nontrivial_gather = [](HloInstruction* inst) { + return inst->opcode() == HloOpcode::kGather && + // Avoid expanding gather ops that produce zero sized tensors, + // instead punt these to ZeroSizedHloElimination. + !ShapeUtil::HasZeroElements(inst->shape()); + }; + + std::vector gather_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), + is_nontrivial_gather); + } + + for (HloInstruction* inst : gather_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandGather(inst)); + TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); + } + + return !gather_instrs.empty(); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h new file mode 100644 index 0000000000000000000000000000000000000000..c1fc8574da99fff223c7dbb570b4533f76905b9a --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass rewrites gather operations into (roughly) while loops of dynamic +// slices. This lets backends that don't support gather directly to +// nevertheless have a minimum level of support. +class GatherExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "gather_expander"; } + StatusOr Run(HloModule* module) override; + + private: + StatusOr ExpandGather(HloInstruction* gather_instr); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GATHER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ba41ee8428cbe7132103df24d552565a8dc2f9f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { +TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2147483647,5] parameter(1) + ROOT gather = s32[2147483647,3,5] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + Status status = GatherExpander{}.Run(module.get()).status(); + EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + + ASSERT_THAT( + status.error_message(), + ::testing::HasSubstr("Gather operations with more than 2147483647 gather " + "indices are not supported.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 78dc0ad4fcd167c93f19d0c2b18ea72d666897ef..a99e2b7794a399047fb5a77a140bd333214e3f23 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -38,14 +38,7 @@ namespace xla { GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size) - : platform_id_(platform_id), pointer_size_(pointer_size) { - // We currently only support kHostPlatformId for CPU, kCudaPlatformId for - // GPU and kInterpreterPlatformId for Interpreter. Before supporting other - // platforms, we need to test this transfer manager on them. - CHECK(platform_id_ == se::host::kHostPlatformId || - platform_id_ == se::interpreter::kInterpreterPlatformId || - platform_id_ == se::cuda::kCudaPlatformId); -} + : platform_id_(platform_id), pointer_size_(pointer_size) {} se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9da4fb97fa27a238fead74985cb481a9be1f4a65..f1707442fe3354d5183d905468810f3871146ff5 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -241,6 +241,7 @@ cc_library( "gpu_executable.cc", "infeed_thunk.cc", "kernel_thunk.cc", + "memset_thunk.cc", "sequential_thunk.cc", "thunk_schedule.cc", "tuple_thunk.cc", @@ -257,6 +258,7 @@ cc_library( "gpu_executable.h", "infeed_thunk.h", "kernel_thunk.h", + "memset_thunk.h", "sequential_thunk.h", "thunk.h", "thunk_schedule.h", @@ -273,6 +275,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -293,6 +296,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cudnn_plugin", "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep + "//tensorflow/stream_executor", ], ) @@ -397,6 +401,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -437,8 +442,10 @@ tf_cc_test( ":fusion_merger", ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -452,6 +459,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", ], @@ -510,9 +518,11 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -690,17 +700,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 1792893ae401bf16d2dd9e861607e8f3821a505e..1eccfe8571ceb5b082f2b47473a38d7405d790b7 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -94,11 +94,17 @@ se::port::StatusOr> ScratchAllocator::AllocateBytes( // Determines whether we can safely perform a winograd non-fused convolution for // the given input and output shapes. This works around b/68264959, an integer // overflow in cuDNNv5 and cuDNNv6. -// -// TODO(jlebar): We shouldn't need this check for cuDNNv7. -bool ShouldIncludeWinogradNonfusedAlgo( - const Shape& input_shape, const Shape& output_shape, - const ConvolutionDimensionNumbers& dnums) { +bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, + const Shape& output_shape, + const ConvolutionDimensionNumbers& dnums, + se::StreamExecutor* stream_exec) { + // Skip this check for cudnn7 and newer. + auto version = + stream_exec->AsDnn()->GetVersion(); + if (version.ok() && version.ValueOrDie().major_version() >= 7) { + return true; + } + int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); @@ -118,20 +124,20 @@ bool ShouldIncludeWinogradNonfusedAlgo( std::vector GetAlgorithms(CudnnConvKind kind, bool with_winograd_nonfused, - se::StreamExecutor* stream_exec_) { + se::StreamExecutor* stream_exec) { std::vector algorithms; switch (kind) { case CudnnConvKind::kBackwardFilter: - CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms( + CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( with_winograd_nonfused, &algorithms)); break; case CudnnConvKind::kBackwardInput: - CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms( + CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( with_winograd_nonfused, &algorithms)); break; case CudnnConvKind::kForward: - CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); + CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, + &algorithms)); break; } @@ -209,8 +215,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( return nullopt; } - const bool use_winograd_nonfused = - ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); + const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( + input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index c137fbc97e29e24edb3603c611a5c8f093bc62a6..3cd30b754c3242f00c704de1afab2282ed827b41 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -45,6 +45,7 @@ void MaybeResolveTupleElements(HloInstruction* instruction, // Returns the bytes read by fusion parameter 'param', by returning the byte // size of 'param' shape (or the cumulative byte sizes of all leaf tuple // elements if 'param' is tuple-shaped). +// // In the special case where all users of 'param' (or all users of a leaf // tuple element if 'param' is tuple-shaped) are Slice instructions, the size // of each slice instruction is accumulated instead, to give a more accurate @@ -63,11 +64,10 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (std::all_of(instruction->users().begin(), instruction->users().end(), - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -199,6 +199,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { ++total_visited_; // Skip 'fusion' instruction if there are no users into which we can merge. if (fusion->users().empty()) { + VLOG(3) << "Not merging " << fusion->name() << ": Has no users."; ++num_fail_no_users_; return Status::OK(); } @@ -208,24 +209,27 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Input fusion instructions need to be rooted at a particular HLO (e.g. // kReduce), so they shouldn't be further fused either. if (fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + VLOG(3) << "Not merging " << fusion->name() << ": Is not loop fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); } // Skip multiple output fusion. It's not yet supported. if (fusion->IsMultiOutputFusion()) { + VLOG(3) << "Not merging " << fusion->name() << ": Is multi-output fusion."; ++num_fail_not_loop_fusion_; return Status::OK(); } // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!std::all_of(fusion->users().begin(), fusion->users().end(), - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() == - HloInstruction::FusionKind::kLoop; - })) { + if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kFusion && + (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Some of its users are not loop/input fusion kernels."; ++num_fail_merge_all_users_; return Status::OK(); } @@ -233,18 +237,17 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if any of its fused instructions are expensive. // This is done to avoid the duplication of expensive instructions, which // would occur if 'fusion' were merged into multiple users. + // // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (!std::all_of(fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), - [](const HloInstruction* instruction) { - if (instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction)) { - return false; - } - return true; - })) { + if (c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { + VLOG(3) << "Not merging " << fusion->name() + << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; return Status::OK(); } @@ -253,6 +256,8 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // exceeds the threshold value. if (CalculateFlopsToBytesRatio(fusion) > FusionMerger::GetThresholdFlopsToBytesRatio()) { + VLOG(3) << "Not merging " << fusion->name() + << ": flops-to-bytes ratio is not favorable."; ++num_fail_flops_to_byte_ratio_; return Status::OK(); } @@ -265,6 +270,9 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { const double merged_to_current_bytes_ratio = merged_bytes_transferred / std::max(1.0, current_bytes_transferred); if (merged_to_current_bytes_ratio > 1.10) { + VLOG(3) << "Not merging " << fusion->name() + << ": merged-to-current-bytes-ratio of " + << merged_to_current_bytes_ratio << " is not favorable."; ++num_fail_net_bytes_transferred_ratio_; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index deef5966b80d1b7f16e9982eed9ac5c7131e9d73..2217776c7d5a5f92c520d56222988f80401be9e4 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -16,257 +16,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { namespace { -class FusionMergerTest : public HloTestBase { - protected: - FusionMergerTest() : module_(CreateNewModule()) {} - - // Builds the following computation: - // - // Param - // / | \ - // / | \ - // OnesVec GTE(0) GTE(1) GTE(2) - // \ / \ / - // Add Add OnesVec - // \ / \ / - // \ Add Mul OnesVec - // \ | | / - // \ Mul Add - // \ | / - // \ | / - // Tuple - // - HloComputation* BuildComputation0() { - auto builder = HloComputation::Builder(TestName() + ".Computation0"); - // Create param instruction to access computation state. - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape3_, "param")); - - // Create GetTupleElement instructions for each tuple element. - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 1)); - auto gte2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, param, 2)); - - // Create const vector of ones to be used in element-wise computations. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // Create simple fusable computation for tuple element 0 (wont get merged). - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, one_vec, gte0)); - - // Create fusable computation which is dependent on second and third tuple - // elements (will initially be fused on its own). - auto add1 = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte1, gte2)); - - // Create two sub-computations, both of which are users of 'add1'. - - // First sub-computation: out1 = Mul(Add(add1, one_vec), one_vec) - auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, add1, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add2, one_vec)); - - // Second sub-computation: out2 = Add(Mul(add1, one_vec), one_vec) - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add1, one_vec)); - auto out2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1, out2})); - return module_->AddEntryComputation(builder.Build()); - } - - // Builds the following computation: - // - // Param - // / \ - // GTE(0) GTE(1) - // | | \ / - // | | Mul - // \ \ | - // \ Mul - // \ | - // OnesVec Mul OnesVec - // \ / \ / - // OnesVec Add Mul OnesVec - // \ | | / - // Mul Add - // \ / - // \ / - // Tuple - // - HloComputation* BuildComputation1() { - auto builder = HloComputation::Builder(TestName() + ".Computation1"); - Shape tuple_shape2_ = ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); - // Create param instruction to access computation state. - auto state = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape2_, "state")); - - // Create shared sub-computation (will initially be fused on its own). - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 2)); - // Calculate the flops we need to generate for this shared computation - // to exceed the threshold flops_to_bytes_ratio. - // Note that bytes transferred is multiplied by 3 because there are two - // operands and one output of size 'data_shape_'. - const int64 flops_needed = FusionMerger::GetThresholdFlopsToBytesRatio() * - ShapeUtil::ByteSizeOf(data_shape_) * 3; - const int64 vec_elements = ShapeUtil::ElementsIn(data_shape_); - const int64 iters = (flops_needed + vec_elements - 1) / vec_elements; - - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, gte0, gte1)); - for (int i = 0; i < iters; ++i) { - mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, gte0, mul0)); - } - - // Create two sub-computations, both of which are users of 'mul0'. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // First sub-computation: out0 = Mul(Add(mul0, one_vec), one_vec) - auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add0, one_vec)); - - // Second sub-computation: out1 = Add(Mul(mul0, one_vec), one_vec) - auto mul1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, mul0, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul1, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_->AddEntryComputation(builder.Build()); - } - - // Builds the following computation: - // - // Param - // / | | \ - // / | | \ - // / | | \ - // GTE(0) GTE(1) GTE(2) GTE(3) - // \ / / / - // Add / / - // \ / / - // Add / - // \ / - // \ / - // OnesVec Add OnesVec - // \ / \ / - // OnesVec Add Mul OnesVec - // \ | | / - // Mul Add - // \ / - // \ / - // Tuple - // - HloComputation* BuildComputation2(bool add_extra_input) { - auto builder = HloComputation::Builder(TestName() + ".Computation2"); - Shape state_shape = add_extra_input ? tuple_shape4_ : tuple_shape3_; - // Create param instruction to access computation state. - auto state = builder.AddInstruction( - HloInstruction::CreateParameter(0, state_shape, "state")); - - // Create GetTupleElement instructions for each tuple element. - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 1)); - auto gte2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 2)); - - // Create shared fusable computation that reduces its operands. - auto reduce0 = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1)); - auto reduce_out = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce0, gte2)); - if (add_extra_input) { - auto gte3 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, state, 3)); - reduce_out = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce_out, gte3)); - } +namespace op = xla::testing::opcode_matchers; - // Create two fusable sub-computations which are dependent on shared - // computation 'reduce_out'. - auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); - - // First sub-computation: out0 = Mul(Add(reduce_out, one_vec), one_vec) - auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, reduce_out, one_vec)); - auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, add2, one_vec)); - - // Second sub-computation: out1 = Add(Mul(reduce_out, one_vec), one_vec) - auto mul0 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kMultiply, reduce_out, one_vec)); - auto out1 = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kAdd, mul0, one_vec)); - - // Create output Tuple. - builder.AddInstruction(HloInstruction::CreateTuple({out0, out1})); - return module_->AddEntryComputation(builder.Build()); - } - - Shape data_shape_ = ShapeUtil::MakeShape(F32, {4}); - Shape tuple_shape2_ = ShapeUtil::MakeTupleShape({data_shape_, data_shape_}); - Shape tuple_shape3_ = - ShapeUtil::MakeTupleShape({data_shape_, data_shape_, data_shape_}); - Shape tuple_shape4_ = ShapeUtil::MakeTupleShape( - {data_shape_, data_shape_, data_shape_, data_shape_}); - - std::unique_ptr module_; -}; +class FusionMergerTest : public HloTestBase {}; // Tests that we can merge a fusion instruction that is below threshold. // -// Original computation: -// -// Param -// / | \ -// / | \ -// OnesVec GTE(0) GTE(1) GTE(2) -// \ / \ / -// Add Add OnesVec -// \ / \ / -// \ Add Mul OnesVec -// \ | | / -// \ Mul Add -// \ | / -// \ | / -// Tuple -// -// Computation after fusion passes: -// -// Param -// / \ -// Fusion3 Fusion2 -// | / \ -// \ Fusion0 Fusion1 -// \ | / -// \ | / -// Tuple -// // Computation after fusion merger pass (Fusion2 is merged into Fusion0 and // Fusion1): // Param @@ -276,19 +40,50 @@ class FusionMergerTest : public HloTestBase { // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto computation = BuildComputation0(); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); - // Run fusion merger pass, which should merge the shared fusion instruction - // into its two users. - EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); - - auto* root = computation->root_instruction(); + auto module = tools::Parse(R"( +HloModule MergeSharedFusionInstruction + +comp.3 { + constant.param_0 = f32[4]{0} parameter(0) + param.param_1.2 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(1) + get-tuple-element.6 = f32[4]{0} get-tuple-element(param.param_1.2), index=0 + ROOT add.7 = f32[4]{0} add(constant.param_0, get-tuple-element.6) +} + +comp.2 { + param.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.4 = f32[4]{0} get-tuple-element(param.param_1.1), index=1 + get-tuple-element.5 = f32[4]{0} get-tuple-element(param.param_1.1), index=2 + ROOT add.6 = f32[4]{0} add(get-tuple-element.4, get-tuple-element.5) +} + +comp.1 { + add.1.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3) +} + +comp { + add.1.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1) + ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY MergeSharedFusionInstruction.Computation0 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + param = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.3 = f32[4]{0} fusion(constant, param), kind=kLoop, calls=comp.3 + fusion.4 = f32[4]{0} fusion(param), kind=kLoop, calls=comp.2 + fusion.5 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp.1 + fusion.6 = f32[4]{0} fusion(constant, fusion.4), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.5, fusion.6) +})") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + + auto* root = module->entry_computation()->root_instruction(); EXPECT_EQ(HloOpcode::kTuple, root->opcode()); // Check operand 0 (not merged). Should have 4 instructions. auto* operand0 = root->operand(0); @@ -307,156 +102,188 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { // Tests that we do not merge a fusion instruction that above flops to bytes // threshold. // -// Original computation: -// -// Param -// / \ -// GTE(0) GTE(1) -// | | \ / -// | | Mul -// \ \ | -// \ Mul -// \ | -// OnesVec Mul OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ | | / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes and fusion merger pass (Fusion2 is not -// merged because it exceeds the threshold flops to bytes ratio). -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - BuildComputation1(); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = tools::Parse(R"( +HloModule FlopsToBytesRatioThresholdExceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.3 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.4 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + multiply.29 = f32[4]{0} multiply(get-tuple-element.3, get-tuple-element.4) + multiply.30 = f32[4]{0} multiply(get-tuple-element.3, multiply.29) + multiply.31 = f32[4]{0} multiply(get-tuple-element.3, multiply.30) + multiply.32 = f32[4]{0} multiply(get-tuple-element.3, multiply.31) + multiply.33 = f32[4]{0} multiply(get-tuple-element.3, multiply.32) + multiply.34 = f32[4]{0} multiply(get-tuple-element.3, multiply.33) + multiply.35 = f32[4]{0} multiply(get-tuple-element.3, multiply.34) + multiply.36 = f32[4]{0} multiply(get-tuple-element.3, multiply.35) + multiply.37 = f32[4]{0} multiply(get-tuple-element.3, multiply.36) + multiply.38 = f32[4]{0} multiply(get-tuple-element.3, multiply.37) + multiply.39 = f32[4]{0} multiply(get-tuple-element.3, multiply.38) + multiply.40 = f32[4]{0} multiply(get-tuple-element.3, multiply.39) + ROOT multiply.41 = f32[4]{0} multiply(get-tuple-element.3, multiply.40) +} + +comp.1 { + multiply.12.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.3 = f32[4]{0} add(multiply.12.param_1.1, constant.param_1.3) + ROOT multiply.16 = f32[4]{0} multiply(add.3, constant.param_1.3) +} + +comp { + multiply.12.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.15 = f32[4]{0} multiply(multiply.12.param_1, constant.param_1.1) + ROOT add.2 = f32[4]{0} add(multiply.15, constant.param_1.1) +} + +ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the flops/bytes of the // shared fusion instruction exceeds the threshold ratio, and therefore // cannot be merged with other fusion instructions. - EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is exceeded. // -// Original computation: -// -// Param -// / | | \ -// / | | \ -// / | | \ -// GTE(0) GTE(1) GTE(2) GTE(3) -// \ / / / -// Add / / -// \ / / -// Add / -// \ / -// \ / -// OnesVec Add OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ | | / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes and fusion merger pass. Fusion2 is not -// merged because it exceeds the threshold bytes transferred. This is because -// the bytes read by Fusion2 (when replicated if the instruction is merged -// into Fusion0 and Fusion1) would exceed the bytes transferred threshold. -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is not merged because it exceeds the threshold bytes transferred. +// This is because the bytes read by Fusion2 (when replicated if the instruction +// is merged into Fusion0 and Fusion1) would exceed the bytes transferred +// threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - BuildComputation2(/*add_extra_input=*/true); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = tools::Parse(R"( +HloModule BytesTransferredThresholdExeceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.8 = f32[4]{0} get-tuple-element(state.param_1.1), index=1 + add.9 = f32[4]{0} add(get-tuple-element.7, get-tuple-element.8) + get-tuple-element.9 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + add.10 = f32[4]{0} add(add.9, get-tuple-element.9) + get-tuple-element.10 = f32[4]{0} get-tuple-element(state.param_1.1), index=3 + ROOT add.11 = f32[4]{0} add(add.10, get-tuple-element.10) +} + +comp.1 { + add.2.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.6 = f32[4]{0} add(add.2.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.6, constant.param_1.3) +} + +comp { + add.2.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.2.param_1, constant.param_1.1) + ROOT add.5 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY BytesTransferredThresholdExeceeded.Computation2 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would increase. - EXPECT_FALSE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); } // Tests that threshold for bytes transferred if merged is not exceeded. // -// Original computation: -// -// Param -// / | \ -// / | \ -// / | \ -// GTE(0) GTE(1) GTE(2) -// \ / / -// Add / -// \ / -// OnesVec Add OnesVec -// \ / \ / -// OnesVec Add Mul OnesVec -// \ / \ / -// Mul Add -// \ / -// \ / -// Tuple -// -// Computation after fusion passes: -// -// Param -// | -// Fusion2 -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// -// Computation after fusion merger pass (Fusion2 is merged into Fusion0 and -// Fusion1, because bytes read from Param by Fusion2 is reduced for this test -// which makes the merge operation into its operand below the bytes -// transferred threshold. -// -// Param -// / \ -// Fusion0 Fusion1 -// \ / -// Tuple -// +// Fusion2 is merged into Fusion0 and Fusion1, because bytes read from Param by +// Fusion2 is reduced for this test which makes the merge operation into its +// operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - BuildComputation2(/*add_extra_input=*/false); - // Run standard fusion passes. - EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/false) - .Run(module_.get()) - .ValueOrDie()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module_.get()) - .ValueOrDie()); + auto module = tools::Parse(R"( +HloModule BytesTransferredThresholdNotExeceeded + +comp.2 { + state.param_1.1 = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + get-tuple-element.5 = f32[4]{0} get-tuple-element(state.param_1.1), index=0 + get-tuple-element.6 = f32[4]{0} get-tuple-element(state.param_1.1), index=1 + add.7 = f32[4]{0} add(get-tuple-element.5, get-tuple-element.6) + get-tuple-element.7 = f32[4]{0} get-tuple-element(state.param_1.1), index=2 + ROOT add.8 = f32[4]{0} add(add.7, get-tuple-element.7) +} + +comp.1 { + add.1.param_1.1 = f32[4]{0} parameter(1) + constant.param_1.3 = f32[4]{0} parameter(0) + add.5 = f32[4]{0} add(add.1.param_1.1, constant.param_1.3) + ROOT multiply.3 = f32[4]{0} multiply(add.5, constant.param_1.3) +} + +comp { + add.1.param_1 = f32[4]{0} parameter(1) + constant.param_1.1 = f32[4]{0} parameter(0) + multiply.2 = f32[4]{0} multiply(add.1.param_1, constant.param_1.1) + ROOT add.4 = f32[4]{0} add(multiply.2, constant.param_1.1) +} + +ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { + constant = f32[4]{0} constant({1, 1, 1, 1}) + state = (f32[4]{0}, f32[4]{0}, f32[4]{0}) parameter(0) + fusion.2 = f32[4]{0} fusion(state), kind=kLoop, calls=comp.2 + fusion.3 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp.1 + fusion.4 = f32[4]{0} fusion(constant, fusion.2), kind=kLoop, calls=comp + ROOT tuple = (f32[4]{0}, f32[4]{0}) tuple(fusion.3, fusion.4) +})") + .ValueOrDie(); // Run fusion merger pass, which should detect that the net bytes transferred // (if merged) would not increase. - EXPECT_TRUE(FusionMerger().Run(module_.get()).ValueOrDie()); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); +} + +// Check that we're willing to merge f1_computation into f2_computation, even +// though f2 is an input fusion node. +TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { + auto module = tools::Parse(R"( + HloModule m + + f1_computation { + f1_p0 = f32[10]{0} parameter(0) + ROOT f1_root = f32[10]{0} add(f1_p0, f1_p0) + } + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + f2_computation { + f2_p0 = f32[10]{0} parameter(0) + f2_mul = f32[10]{0} multiply(f2_p0, f2_p0) + f2_zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(f2_mul, f2_zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10]{0} parameter(0) + f1 = f32[10]{0} fusion(p0), kind=kLoop, calls=f1_computation + ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Fusion(op::Parameter())); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 8e3aebbc12b5e6d746700956b9743bc94db50167..38668ff455a44c7ef99b57b750f1a3b18a90bd2c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -49,7 +49,7 @@ struct MatrixDescriptor { // rhs_matrix, and stores the result to output_matrix. template bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, se::Stream* stream) { DCHECK(!output_matrix.transpose); se::DeviceMemory lhs_data(lhs_matrix.data); @@ -65,7 +65,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, return stream ->ThenBlasGemm( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, + output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha, lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, &output_data, /*leading dim of output=*/output_matrix.num_rows) @@ -89,7 +89,7 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, template bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, + MatrixDescriptor output_matrix, double alpha, se::blas::ComputationType computation_type, se::blas::AlgorithmType algorithm, se::Stream* stream, se::blas::ProfileResult* output_profile_result) { @@ -108,11 +108,13 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, return stream ->ThenBlasGemmWithAlgorithm( lhs_transpose, rhs_transpose, output_matrix.num_rows, - output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/1.0, - lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, - /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, - &output_data, /*leading dim of output=*/output_matrix.num_rows, - computation_type, algorithm, output_profile_result) + output_matrix.num_cols, /*size of reduce dim=*/k, + /*alpha=*/static_cast(alpha), lhs_data, + /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data, + /*leading dim of RHS=*/rhs_matrix.num_rows, + /*beta=*/static_cast(0.0f), &output_data, + /*leading dim of output=*/output_matrix.num_rows, computation_type, + algorithm, output_profile_result) .ok(); } @@ -125,8 +127,8 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix, template StatusOr DoGemmAutotune( MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix, - MatrixDescriptor output_matrix, se::blas::ComputationType computation_type, - se::Stream* stream) { + MatrixDescriptor output_matrix, double alpha, + se::blas::ComputationType computation_type, se::Stream* stream) { std::vector algorithms; CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms)); @@ -137,9 +139,9 @@ StatusOr DoGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. - DCHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - computation_type, algorithm, stream, - &profile_result)); + CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, + alpha, computation_type, algorithm, + stream, &profile_result)); if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { @@ -161,6 +163,8 @@ StatusOr DoGemmAutotune( // DoGemm/DoGemmWithAlgorithm/DoGemmAutotune. auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { switch (type) { + case F16: + return &DoGemm; case F32: return &DoGemm; case F64: @@ -172,6 +176,8 @@ auto GetGemmFn(PrimitiveType type) -> decltype(&DoGemm) { auto GetGemmWithAlgorithmFn(PrimitiveType type) -> decltype(&DoGemmWithAlgorithm) { switch (type) { + case F16: + return &DoGemmWithAlgorithm; case F32: return &DoGemmWithAlgorithm; case F64: @@ -182,6 +188,8 @@ auto GetGemmWithAlgorithmFn(PrimitiveType type) } auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { switch (type) { + case F16: + return &DoGemmAutotune; case F32: return &DoGemmAutotune; case F64: @@ -196,6 +204,10 @@ auto GetGemmAutotuneFn(PrimitiveType type) -> decltype(&DoGemmAutotune) { // separately from the precision of the inputs and result. se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { switch (type) { + case F16: + // Use F32 as computation type for F16 as we currently only implement the + // cuDNN pseudo half configuration for half precision. + return se::blas::ComputationType::kF32; case F32: return se::blas::ComputationType::kF32; case F64: @@ -212,7 +224,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, const HloInstruction* hlo_instruction) + bool transpose_rhs, double alpha, + const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), @@ -221,7 +234,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, rhs_shape_(rhs_shape), output_shape_(output_shape), transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs) {} + transpose_rhs_(transpose_rhs), + alpha_(alpha) {} tensorflow::Status GemmThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { @@ -290,7 +304,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (autotune_it == autotune_results_.end()) { StatusOr best_algorithm = GetGemmAutotuneFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - computation_type, stream); + alpha_, computation_type, stream); autotune_it = autotune_results_.insert({device_name, best_algorithm}).first; @@ -311,12 +325,15 @@ tensorflow::Status GemmThunk::ExecuteOnStream( VLOG(2) << "Using algorithm " << algorithm << " chosen by autotuning on GemmThunk " << this; return GetGemmWithAlgorithmFn(element_type)( - lhs_matrix, rhs_matrix, output_matrix, computation_type, algorithm, - stream, + lhs_matrix, rhs_matrix, output_matrix, alpha_, computation_type, + algorithm, stream, /*output_profile_result=*/nullptr); } + + // Autotune will fail when CUDA 8 and GPU sm_50 or older are used. + // Use the older Gemm API in this case. return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix, - stream); + alpha_, stream); }; bool launch_ok; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 8c6a1f51a8a09ef78950dfe7e89994a3fe247f49..df3edcefef898d465cd5ddc53e5d06a966a31f88 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -34,15 +34,16 @@ namespace gpu { // This is thread-compatible. class GemmThunk : public Thunk { public: - // Constructs a thunk that computes "output = lhs rhs" using BLAS gemm. - // transpose_lhs and transpose_rhs indicate whether gemm should transpose the - // lhs and rhs operand. hlo_instruction is as in Thunk. + // Constructs a thunk that computes "output = (lhs rhs) * alpha" using + // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should + // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is + // a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, - const HloInstruction* hlo_instruction); + double alpha, const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -72,6 +73,7 @@ class GemmThunk : public Thunk { const bool transpose_lhs_; const bool transpose_rhs_; + const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune // results. The map's value is the best algorithm we've found for this thunk diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 28ebd034ee0c89137f4e6eb417d8a37f4a00af7a..07be2a0cf90c326af6e41764e79950db546e43e4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -33,8 +33,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -164,6 +166,9 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, /*rewrite_grad_op=*/true, /*use_fusion=*/false); + // Rewrite gather ops into smaller ones. + pass.AddPass(); + // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. pipeline.AddPass(); @@ -176,6 +181,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, pass.AddPass(); pass.AddPass(); pass.AddPass(); + pass.AddPass(); } pipeline.AddPass( @@ -241,6 +247,22 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout()); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); + pipeline.AddPass(/*is_layout_sensitive=*/true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { HloPassFix fusion("fusion"); fusion.AddInvariantChecker(); @@ -277,15 +299,6 @@ tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { HloPassPipeline pipeline("GPU-ir-emit-prepare"); pipeline.AddInvariantChecker(); - pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - [](const Shape&, const Shape&) { return true; }); - pipeline.AddPass(/*is_layout_sensitive=*/true); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which // materializes the value) or missing a necessary copy (later pass removes an @@ -658,6 +671,8 @@ StatusOr> GpuCompiler::RunBackend( if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + cost_analysis.set_bytes_per_second( + stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 916b556fd43a453a4da2c96217e74c367f8c7653..9db85bc788bde46c890a46ce9b0902ddce3f5675 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -49,7 +49,7 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants // in IR. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 623d6714de501000e38b7698620925f66425f157..28f93447953b90d8a7fa4386e2355066c0405aec 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -46,12 +46,14 @@ namespace { class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, - se::Stream* stream, - const HloComputation* computation) + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) : do_profile_(do_profile), profile_(profile), stream_(stream), + sub_streams_(sub_streams), computation_(computation) { if (do_profile_) { clock_rate_ghz_ = @@ -70,6 +72,7 @@ class HloExecutionProfiler { CHECK(!finished_execution_) << "Call FinishExecution only once!"; finished_execution_ = true; if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(execution_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( @@ -88,6 +91,7 @@ class HloExecutionProfiler { // that the hlo_instruction took to execute in the profile. void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(per_op_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( @@ -100,6 +104,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; @@ -147,13 +152,9 @@ Status GpuExecutable::ExecuteThunks( LOG(WARNING) << "PROFILING: profiling is enabled"; } - HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, - hlo_module_->entry_computation()); - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // Stream 0 indicates `main_stream` and substreams start from stream 1. std::vector::SmartPtr> sub_streams; + sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); TF_ASSIGN_OR_RETURN( @@ -161,6 +162,10 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, + sub_streams, hlo_module_->entry_computation()); + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + // The next event enqueued on stream N must not run until the thunk at // last_blocking_thunk_for_stream[N] completes. std::map last_blocking_thunk_for_stream; @@ -262,16 +267,22 @@ StatusOr> GpuExecutable::ExecuteOnStream( ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { - // The caller must give us a buffer for ShapeIndex {} of every parameter. - // It can optionally give us a buffer for other ShapeIndices, but we - // ignore them: Because we can't rely on these sub-buffers' addresses - // being available, our generated code can't use them. Instead, it must - // chase pointers starting at the tuple root. - if (allocation.param_shape_index().empty()) { - auto param_no = allocation.parameter_number(); - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->root_buffer()); + auto param_no = allocation.parameter_number(); + se::DeviceMemoryBase buffer = + arguments[param_no]->buffer(allocation.param_shape_index()); + + // All top-level buffers and sub-buffers must have an explicit, non-null + // pointer, except for zero-sized buffers, which may be null. + if (buffer.is_null() && buffer.size() > 0) { + return FailedPrecondition( + "Cannot run XLA computation because pointer to (sub-)buffer at " + "index %s of parameter %lld was null. All pointers to " + "(sub-)buffers must not be null, unless the (sub-)buffer has zero " + "elements.", + allocation.param_shape_index().ToString().c_str(), param_no); } + + buffer_allocations_builder.RegisterBuffer(i, buffer); } } se::StreamExecutor* executor = run_options->stream()->parent(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index b19cfd43debd0a5490495d176fa2f1fcd625da07..dcb3991f41a31db84d8e9e555ae7d13c3ac84b97 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -83,11 +83,6 @@ class GpuExecutable : public Executable { const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; - const Status EqualOrFail(const Executable& executable) { - // TODO(b/62952745) Implement equality test on GPU executable. - return Unimplemented("Equality test on GPU executable is not implemented."); - } - private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index b5962f069bf499c913bd5479f263a7cb77c00555..85ecbe8fdb34700ca738b99ddd9ea615afc35da3 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -25,13 +25,19 @@ namespace gpu { namespace { bool IsFusile(const HloInstruction& hlo) { + // Don't fuse get-tuple-element on GPU: We can, but it's slower than not + // fusing. We never generate kernels for unfused GTEs. Instead, if an + // unfused GTE is an input to a kernel (including a fusion kernel), we + // compute the address of the GTE at the top of the kernel. Often we know the + // address of the GTE result statically, so we can do this without chasing any + // pointers. return (hlo.IsElementwise() && hlo.operand_count() > 0) || + hlo.opcode() == HloOpcode::kBitcast || hlo.opcode() == HloOpcode::kBroadcast || hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || - hlo.opcode() == HloOpcode::kGetTupleElement || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || @@ -46,6 +52,34 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Check if we can use output fusion for (A @ B) * alpha + if (producer->opcode() == HloOpcode::kDot) { + if (consumer->opcode() == HloOpcode::kMultiply) { + CHECK_EQ(consumer->operand_count(), 2); + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + if (alpha->opcode() == HloOpcode::kConstant && + ShapeUtil::IsScalar(alpha->shape())) { + return true; + } + } + } + + // Only allow to fuse transpose into an output fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { + if (producer->opcode() != HloOpcode::kTranspose) { + return false; + } + // Check that the transpose is the operand of a dot. + auto producer_operand_index = consumer->operand_index(producer); + auto fused_parameter = consumer->fused_parameter(producer_operand_index); + const std::vector& fused_parameter_users = + fused_parameter->users(); + return (fused_parameter_users.size() == 1 && + fused_parameter_users[0]->opcode() == HloOpcode::kDot); + } + // Output fusion is not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; @@ -70,17 +104,6 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // We may need to know original operand layout to emit input fusion, and so - // far, we merely use the layout of an operand of the fusion node, which means - // we must fuse only elementwise operations. This restriction should be lifted - // later if we need to fuse other operations, e.g. transpose, for performance. - if ((IsReductionToVector(*consumer) || - (HloOpcode::kFusion == consumer->opcode() && - HloInstruction::FusionKind::kInput == consumer->fusion_kind())) && - !producer->IsElementwise()) { - return false; - } - // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && @@ -98,6 +121,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } + if (producer->opcode() == HloOpcode::kDot) { + return HloInstruction::FusionKind::kOutput; + } if (HloOpcode::kFusion == consumer->opcode()) { return consumer->fusion_kind(); } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 2d6dad27a59978da6e4719afc50ebee5e641dde0..4b231c449f8f101127b4d30bfff20c69d8cef5c1 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; @@ -137,30 +138,119 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, GetTupleElementFused) { - HloComputation::Builder builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "param")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, param, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, param, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); +// Tests that broadcasts fused into a fusion with a reduce root. +TEST_F(InstructionFusionTest, BroadcastIntoReduce) { + auto module = tools::Parse(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY BroadcastIntoReduce { + constant = f32[] constant(1) + broadcast = f32[16,16,16,16]{3,2,1,0} broadcast(constant), dimensions={} + constant.1 = f32[] constant(0) + ROOT reduce = f32[] reduce(broadcast, constant.1), dimensions={0,1,2,3}, + to_apply=add + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); +} + +TEST_F(InstructionFusionTest, BitcastIntoAdd) { + auto module = tools::Parse(R"( + HloModule test_module + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + bitcast = f32[4,1]{1,0} bitcast(p0) + ROOT add = f32[4,1] add(bitcast, p1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Bitcast(op::Parameter()), op::Parameter())); +} + +TEST_F(InstructionFusionTest, AddIntoBitcast) { + auto module = tools::Parse(R"( + HloModule test_module + + ENTRY BroadcastIntoAdd { + p0 = f32[4,1,1]{2,1,0} parameter(0) + p1 = f32[4,1]{1,0} parameter(1) + add = f32[4,1] add(p0, p1) + ROOT bitcast = f32[4,1,1] bitcast(add) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Bitcast(op::Add(op::Parameter(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, DontFuseGTE) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY DontFuseGTE { + p0 = (f32[10], f32[10]) parameter(0) + gte0 = f32[10] get-tuple-element(p0), index=0 + gte1 = f32[10] get-tuple-element(p0), index=1 + ROOT add = f32[10] add(gte0, gte1) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DotOutputFusion) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + constant = f32[] constant(3) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} + dot = f32[4,4]{1,0} dot(p0, transpose) + ROOT mul = f32[4,4] multiply(constant, dot) + })") + .ValueOrDie(); + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) .ValueOrDie()); - HloInstruction* root = computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - HloInstruction* fused_root = root->fused_expression_root(); - EXPECT_EQ(HloOpcode::kAdd, fused_root->opcode()); - // Check that operands of 'fused_root' are GTE. - EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(0)->opcode()); - EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT( + root->fused_expression_root(), + op::Multiply(op::Parameter(), + op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 2f65edffea81db7dba1f8545f92b27ea622044e7..532d436ee82b985a4efe300f90223e1298e85765 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -49,8 +49,10 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, // The inputs and the output must // 1) be matrices with no padding and a non-zero number of elements, // 2) have an allowed element type. - bool type_is_allowed = (output_shape.element_type() == F32 || - output_shape.element_type() == F64); + PrimitiveType output_primitive_type = output_shape.element_type(); + bool type_is_allowed = + (output_primitive_type == F16 || output_primitive_type == F32 || + output_primitive_type == F64); return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape) && @@ -87,6 +89,19 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return true; } + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput && + hlo.fused_expression_root()->opcode() == HloOpcode::kMultiply) { + // Try to find the dot inside the output fusion node. + const HloInstruction* dot = hlo.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo.fused_expression_root()->operand(1); + } + if (dot->opcode() == HloOpcode::kDot) { + return ImplementedAsGemm(*dot); + } + } + return false; } @@ -145,14 +160,19 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn - // algorithm to use. It's up to a later pass to choose the algorithm, so to - // indicate that we haven't yet made a choice, we speicfy -1 for that arg. + // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn + // algorithm to use, and a boolean indicating whether to use tensor cores. + // + // It's up to a later pass to choose the algorithm and decide whether to use + // tensor cores, so to indicate that we haven't yet made a choice, we speicfy + // -1 and false for those args. HloInstruction* negative_one = computation->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction* false_constant = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloInstruction* custom_call = computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one}, call_target)); + call_shape, {lhs, rhs, negative_one, false_constant}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a3df67a87344d6ece2ea9047321ad9542c13f8cf..1e0db2821a2c212d0f212ae94ab69231bc6053ea 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -438,6 +439,32 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { return IrEmitter::DefaultAction(select); } +namespace { +llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { + return ir_builder->CreateExtractValue(x, {0}); +} + +llvm::Value* Imag(llvm::Value* x, llvm::IRBuilder<>* ir_builder) { + return ir_builder->CreateExtractValue(x, {1}); +} + +std::pair MultiplyComplex( + llvm::Value* lhs_value, llvm::Value* rhs_value, + llvm::IRBuilder<>* ir_builder) { + llvm::Value* lhs_real = Real(lhs_value, ir_builder); + llvm::Value* lhs_imag = Imag(lhs_value, ir_builder); + llvm::Value* rhs_real = Real(rhs_value, ir_builder); + llvm::Value* rhs_imag = Imag(rhs_value, ir_builder); + llvm::Value* real_result1 = ir_builder->CreateFMul(lhs_real, rhs_real); + llvm::Value* real_result2 = ir_builder->CreateFMul(lhs_imag, rhs_imag); + llvm::Value* real_result = ir_builder->CreateFSub(real_result1, real_result2); + llvm::Value* imag_result1 = ir_builder->CreateFMul(lhs_real, rhs_imag); + llvm::Value* imag_result2 = ir_builder->CreateFMul(lhs_imag, rhs_real); + llvm::Value* imag_result = ir_builder->CreateFAdd(imag_result1, imag_result2); + return {real_result, imag_result}; +} +} // namespace + Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); @@ -456,21 +483,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { - auto real = [&](llvm::Value* x) { - return ir_builder_.CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_.CreateExtractValue(x, {1}); - }; - llvm::Value* real_result = ir_builder_.CreateFSub( - ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value))); - llvm::Value* imag_result = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value))); + auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = ir_builder_.CreateInsertValue(result, real_result, {0}); - result = ir_builder_.CreateInsertValue(result, imag_result, {1}); + result = ir_builder_.CreateInsertValue(result, value.first, {0}); + result = ir_builder_.CreateInsertValue(result, value.second, {1}); } else { result = ir_builder_.CreateFMul(lhs_value, rhs_value); } @@ -548,20 +564,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { llvm::Value* accum = ir_builder_.CreateLoad(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { -#define REAL(x) ir_builder_.CreateExtractValue(x, {0}) -#define IMAG(x) ir_builder_.CreateExtractValue(x, {1}) - llvm::Value* product_real = ir_builder_.CreateFSub( - ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)), - ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element))); - llvm::Value* product_imag = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)), - ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element))); - updated_accum = ir_builder_.CreateInsertValue( - accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0}); - updated_accum = ir_builder_.CreateInsertValue( - updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1}); -#undef IMAG -#undef REAL + auto value = MultiplyComplex(lhs_element, rhs_element, &ir_builder_); + llvm::Value* accum_real = Real(accum, &ir_builder_); + llvm::Value* real_sum = ir_builder_.CreateFAdd(accum_real, value.first); + updated_accum = ir_builder_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* accum_imag = Imag(accum, &ir_builder_); + llvm::Value* imag_sum = ir_builder_.CreateFAdd(accum_imag, value.second); + updated_accum = ir_builder_.CreateInsertValue(updated_accum, imag_sum, {1}); } else { llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); updated_accum = ir_builder_.CreateFAdd(accum, product); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index aa2a0a9800bab142481e1def785c9052526fcd8c..26e497762f2a6f23767c5b98f339eefdef0b7468 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include @@ -44,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -142,37 +145,6 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } -// Tries to get a Slice for the given instruction at the given index, but -// returns nullopt if we might not know the slice's address at runtime without -// dereferencing a containing tuple. -// -// In particular, when XLA accepts a parameter of tuple type, the caller has the -// option of telling XLA what are the values inside of the tuple, or just giving -// XLA a pointer to the top-level tuple and letting us chase the pointers on the -// GPU. We therefore cannot rely having these pointers to parameter sub-buffers -// being present when we run the program. -optional GetKnownAtRuntimeSlice( - const HloInstruction* instr, const ShapeIndex& index, - const BufferAssignment& buffer_assn) { - auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index); - if (!maybe_slice.ok()) { - return nullopt; - } - // BufferAllocation gives a slice and alloc to every buffer accessed by XLA, - // but we don't necessarily know the runtime address of sub-buffers of input - // parameters. - const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie(); - const BufferAllocation* alloc = slice.allocation(); - if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() && - !alloc->param_shape_index().empty()) { - return nullopt; - } - - // Otherwise, we will know the address of this slice at runtime without having - // to dereference a tuple. - return slice; -} - } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -203,7 +175,7 @@ bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, return hlo.opcode() == HloOpcode::kCopy && hlo.operand(0)->opcode() == HloOpcode::kConstant && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value(); + buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok(); } bool ImplementedAsDeviceToDeviceMemcpy( @@ -213,13 +185,13 @@ bool ImplementedAsDeviceToDeviceMemcpy( // // 1. `hlo` is a kCopy instruction. // 2. `hlo` and its operand have the same shape (thus the same layout too). - // 3. The operand to `hlo` has a buffer assignment (constants do not, for - // instance) which means the source buffer also resides on the device. + // 3. `hlo` and its operand have a statically-known buffer assignment + // (constants do not, for instance), which means the source buffer also + // resides on the device. return hlo.opcode() == HloOpcode::kCopy && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() && - GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment) - .has_value(); + buffer_assignment.GetUniqueTopLevelSlice(&hlo).ok() && + buffer_assignment.GetUniqueTopLevelSlice(hlo.operand(0)).ok(); } } // namespace @@ -498,12 +470,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { switch (root->opcode()) { case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(fusion)); - TF_RETURN_IF_ERROR(EmitInitializer( - fusion, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(fusion)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; @@ -517,39 +488,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); Shape input_shape = root->operand(0)->shape(); - // EmitReductionToVector requires the input shape to have a layout, but - // fused instructions don't have one. So we determine its layout from - // the fusion's operands. The choice of the layout only affects - // performance but not correctness. - auto choose_input_layout = []( - tensorflow::gtl::ArraySlice operands, - Shape* input_shape) -> Status { - // Prefer the layout of an operand whose shape is compatible with - // input_shape. - for (const HloInstruction* operand : operands) { - if (ShapeUtil::Compatible(*input_shape, operand->shape())) { - return LayoutUtil::CopyLayoutBetweenShapes(operand->shape(), - input_shape); - } - } - // If no operand has a compatible shape, prefer an operand that has - // the same rank at least. - for (const HloInstruction* operand : operands) { - if (ShapeUtil::Rank(*input_shape) == - ShapeUtil::Rank(operand->shape())) { - // Do not use CopyLayoutBetweenShapes because input_shape and - // operand->shape() may be incompatible. - *input_shape->mutable_layout() = operand->shape().layout(); - return Status::OK(); - } - } - // When all the above fails, which is rare, set the default layout. - LayoutUtil::SetToDefaultLayout(input_shape); - return Status::OK(); - }; - TF_RETURN_IF_ERROR( - choose_input_layout(fusion->operands(), &input_shape)); - return EmitReductionToVector( root, input_shape, fused_emitter.GetGenerator(root->operand(0)), fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), @@ -598,7 +536,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); + + int max_unroll_factor = fusion->GetModule() + ->config() + .debug_options() + .xla_gpu_max_kernel_unroll_factor(); + + // Find the largest possible power of two to unroll by. + // TODO(kramerb): Make this smarter. + int unroll_factor = 1; + if (!fusion->IsMultiOutputFusion()) { + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int64 num_elements = ShapeUtil::ElementsIn(fusion->shape()); + for (int i = max_unroll_factor; i > 1; i /= 2) { + if (num_elements % i == 0) { + unroll_factor = i; + break; + } + } + } + + thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); } @@ -1668,14 +1626,14 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { if (IsReductionToVector(*reduce) && // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(reduce)); std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(reduce)); - TF_RETURN_IF_ERROR(EmitInitializer( - reduce, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(reduce)); + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(reduce)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), reduce)); + return EmitReductionToVector( reduce, input->shape(), [&](const llvm_ir::IrArray::Index& index) { @@ -1739,16 +1697,13 @@ Status IrEmitterUnnested::HandleSelectAndScatter( CHECK_EQ(rank, ShapeUtil::Rank(source->shape())); CHECK_EQ(rank, window.dimensions_size()); - { - std::vector> thunks; - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - TF_RETURN_IF_ERROR(EmitInitializer( - select_and_scatter, static_cast(thunks.back().get()))); - bindings_.UnbindAllLocalIrValues(); - thunks.emplace_back(BuildKernelThunk(select_and_scatter)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(select_and_scatter)); + std::vector> thunks; + thunks.push_back(std::move(initializer_thunk)); + thunks.push_back(BuildKernelThunk(select_and_scatter)); + thunk_sequence_->emplace_back( + MakeUnique(std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1993,38 +1948,54 @@ GetHloBufferSlices(const HloInstruction* hlo, -> optional> { // Simple, common case: Is the buffer for instr known at runtime? If so, // we're done. - auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn); - if (slice.has_value()) { - return {{*slice, ShapeIndex()}}; + auto slice = buffer_assn.GetUniqueSlice(instr, index); + if (slice.ok()) { + return {{slice.ValueOrDie(), ShapeIndex()}}; } - // If we don't know the buffer for instr at index, see if we know the buffer - // for instr at index without its last element. If so, we can dynamically - // find the buffer for instr by dereferencing a pointer in that buffer. - // Continue looking this way until we run out of elements in 'index'. - ShapeIndex new_index = index; - ShapeIndex gte_indices; - while (!new_index.empty()) { - gte_indices.push_front(new_index.back()); - new_index.pop_back(); - auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn); - if (slice.has_value()) { - return {{*slice, gte_indices}}; + // If that didn't work, walk up any bitcasts that we might see. These must + // appear before any GTE instructions, because it's illegal to bitcast to a + // tuple type. + const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kBitcast) { + parent = parent->operand(0); + + auto slice = buffer_assn.GetUniqueSlice(parent, {}); + if (slice.ok()) { + return {{slice.ValueOrDie(), ShapeIndex()}}; } } - // If *that* didn't work, check whether instr is a GTE instruction. If it - // is, see if we can get a buffer for its parent, and continue walking up - // parents until we find a defined buffer or we hit something that's not a - // GTE. - const HloInstruction* parent = instr; + // Check whether instr is a GTE instruction. If it is, see if we can get a + // buffer for its parent, and continue walking up parents until we find a + // defined buffer or we hit something that's not a GTE. + ShapeIndex gte_indices; while (parent->opcode() == HloOpcode::kGetTupleElement) { gte_indices.push_front(parent->tuple_index()); parent = parent->operand(0); - auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); - if (slice.has_value()) { - return {{*slice, gte_indices}}; + auto slice = buffer_assn.GetUniqueSlice(parent, {}); + if (slice.ok()) { + return {{slice.ValueOrDie(), gte_indices}}; + } + } + + // Finally, if we don't know the buffer for instr at index, see if we know + // the buffer for instr at index without its last element. If so, we can + // dynamically find the buffer for instr by dereferencing a pointer in that + // buffer. Continue looking this way until we run out of elements in + // 'index'. + // + // We can almost always get a buffer without resorting to this. The only + // exception is for cases where the relevant sub-buffer is truly unknowable, + // for example the sub-buffer of a tuple-shaped select. + ShapeIndex new_index = index; + while (!new_index.empty()) { + gte_indices.push_front(new_index.back()); + new_index.pop_back(); + auto slice = buffer_assn.GetUniqueSlice(instr, new_index); + if (slice.ok()) { + return {{slice.ValueOrDie(), gte_indices}}; } } @@ -2064,8 +2035,13 @@ GetHloBufferSlices(const HloInstruction* hlo, return slices; } -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst) { +Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { + // TODO(b/72710576): Gather is not implemented on GPUs + return Unimplemented("Gather is not implemented on GPUs."); +} + +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( + const HloInstruction* inst, int unroll_factor) { const BufferAssignment& buffer_assn = ir_emitter_context_->buffer_assignment(); @@ -2157,7 +2133,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( } return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), - inst); + inst, unroll_factor); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -2216,31 +2192,63 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( inst->shape(), // The shape of the output. false, // Do not transpose LHS. false, // Do not transpose RHS. + 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - const HloInstruction* dot = inst->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Trasnpose RHS. - inst); + if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) { + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); + } + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*mul), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + dot->operand(0)->IsRank2Transpose(), // Transpose LHS. + dot->operand(1)->IsRank2Transpose(), // Transpose RHS. + alpha->literal().Get({0}), // alpha. + inst); + } else { + const HloInstruction* dot = inst->fused_expression_root(); + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + dot->operand(0)->IsRank2Transpose(), // Transpose LHS. + dot->operand(1)->IsRank2Transpose(), // Transpose RHS. + 1.0, // Alpha. + inst); + } } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2256,37 +2264,87 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( /*output_shape=*/inst->shape(), inst); } -Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, - KernelThunk* thunk) { +StatusOr> IrEmitterUnnested::BuildInitializerThunk( + const HloInstruction* hlo) { bool fused = HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - CHECK(inst->opcode() == HloOpcode::kSelectAndScatter || - inst->opcode() == HloOpcode::kReduce); - const HloInstruction* init_value = nullptr; - switch (inst->opcode()) { - case HloOpcode::kSelectAndScatter: - init_value = inst->operand(2); - break; - case HloOpcode::kReduce: - init_value = inst->operand(1); - break; - default: - LOG(FATAL) << "Opcode " << inst->opcode() - << " should not need an initializer."; - } + const HloInstruction* init_value = [&] { + switch (inst->opcode()) { + case HloOpcode::kSelectAndScatter: + return inst->operand(2); + case HloOpcode::kReduce: + return inst->operand(1); + default: + LOG(FATAL) << "Opcode " << inst->opcode() + << " should not need an initializer."; + } + }(); if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } - return EmitTargetElementLoopInThunk( + // In the common case, the initializer is a constant. In this case, emit a + // device-memset call if we can. Currently StreamExecutor only supports + // zeroing and 32-bit memsets. + if (init_value->IsConstant()) { + CHECK(ShapeUtil::IsScalar(init_value->shape())); + int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape()); + const auto& literal = init_value->literal(); + + // Are all the bytes of this scalar equal to 0? If so, we can create a + // MemzeroThunk. + ArraySlice literal_bytes( + reinterpret_cast(literal.untyped_data()), num_bytes); + if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + } + + // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by + // repeating the literal 4 or 2 times, so long as the destination buffer is + // an even multiple of 32 bits long. + if ((num_bytes == 1 || num_bytes == 2) && + ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + uint16 pattern16; + if (num_bytes == 1) { + uint8 b = literal_bytes.front(); + pattern16 = uint16{b} | (uint16{b} << 8); + } else { + pattern16 = literal_bytes.front(); + } + uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); + return {MakeUnique(pattern32, + GetAllocationSlice(*hlo), hlo)}; + } + + // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit + // memset so long as all 32-bit words of the scalar are equal to each other. + if (num_bytes >= 4 && num_bytes % 4 == 0 && + memcmp(literal_bytes.data(), literal_bytes.data() + 4, + literal_bytes.size() - 4) == 0) { + uint32 word; + memcpy(&word, literal_bytes.data(), sizeof(word)); + return {MakeUnique(word, GetAllocationSlice(*hlo), + hlo)}; + } + } + + // Otherwise fall back to our slow initializer code. + std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); + TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, - thunk); + kernel_thunk.get())); + + // Clean up state left behind by emitting the loop above. (This is normally + // done in IrEmitterUnnested::Postprocess().) + bindings_.UnbindAllLocalIrValues(); + + // Convert unique_ptr to StatusOr>. + return {std::move(kernel_thunk)}; } namespace { @@ -2447,21 +2505,28 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + int unroll_factor = thunk->unroll_factor(); VLOG(3) << bindings_.ToString(); const Shape& element_shape = hlo.IsMultiOutputFusion() ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); + VLOG(3) << "EmitTargetElementLoopInThunk " + << ShapeUtil::HumanStringWithLayout(hlo.shape()) + << " for unroll_factor " << unroll_factor; LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - element_shape, ir_emitter_context_->device_description()); + element_shape, ir_emitter_context_->device_description(), unroll_factor); UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, unroll_factor) .EmitLoop(IrName(&hlo)); } + CHECK_EQ(unroll_factor, 1) + << "multi-output fusion does not support unrolling"; + // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 688760efbd2c725a4bf48e45eb6f2734b63d25e1..b842f480c6257c1a8bee8cdac55e29c5db6801a0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -67,6 +67,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; + Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; @@ -147,13 +148,12 @@ class IrEmitterUnnested : public IrEmitter { tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reducer); - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned - // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + // Thunk object. The kernel implementation will be unrolled if unroll_factor + // is greater than one. + std::unique_ptr BuildKernelThunk(const HloInstruction* inst, + int unroll_factor = 1); // Returns a FftThunk that calls cuFFT to implement `inst`. std::unique_ptr BuildFftThunk(const HloInstruction* inst); @@ -162,6 +162,11 @@ class IrEmitterUnnested : public IrEmitter { // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + // Returns a thunk that, given a reduce or select-and-scatter op, initializes + // its memory to the appropriate initial value. + StatusOr> BuildInitializerThunk( + const HloInstruction* hlo); + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index c20a781a33fe89af4740ed31dd5bfb1a64473057..c24dc1457f83c7557430a69baf806ed05b45adca 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -30,10 +30,12 @@ namespace gpu { KernelThunk::KernelThunk( tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction) + const string& kernel_name, const HloInstruction* hlo_instruction, + int unroll_factor) : Thunk(Kind::kKernel, hlo_instruction), args_(args.begin(), args.end()), - kernel_name_(kernel_name) {} + kernel_name_(kernel_name), + unroll_factor_(unroll_factor) {} tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { tensorflow::mutex_lock lock(mutex_); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 9ae455e2fcc253a7a08ff95764721048a16b0bf7..df8971b083fe70588f8c32f977981e365d78fdb8 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -47,12 +47,14 @@ class KernelThunk : public Thunk { // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. KernelThunk(tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction); + const string& kernel_name, const HloInstruction* hlo_instruction, + int unroll_factor); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; ~KernelThunk() override = default; const string& kernel_name() const { return kernel_name_; } + int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); tensorflow::Status Initialize(const GpuExecutable& executable) override; @@ -69,6 +71,10 @@ class KernelThunk : public Thunk { // Entry kernel name for the computation. const string kernel_name_; + // The number of times this kernel should be unrolled. This works as a + // multiplier on the number of elements produced by a GPU thread. + const int unroll_factor_; + // The thread and block dimension used to launch the kernel. // Will be set by IrEmitterUnnested. LaunchDimensions launch_dimensions_; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index f4c4dcdafd6cc0cd64da5a8d1f23c8c0e7b2a9cb..86c4ac18b0501c38aaaae5a007bddcf261ca338f 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -68,17 +68,3 @@ tf_cc_test( "@llvm//:support", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index cfabae791d26d0eb49826085ad7ad166a19109a1..df9d9be889ce839ee665cd4820b169c124d9fcde 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -34,7 +34,7 @@ limitations under the License. #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.def" +#include "llvm/CodeGen/CommandFlags.inc" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -252,7 +252,7 @@ void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); } - llvm::WriteBitcodeToFile(&module, outfile.os()); + llvm::WriteBitcodeToFile(module, outfile.os()); outfile.keep(); } diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..18e673542c5b47cb90d31a8eff62a5e4adb78d1d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/stream_executor/stream_executor.h" + +namespace xla { +namespace gpu { + +namespace se = ::perftools::gputools; + +Status MemzeroThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemZero(&dest_data, dest_data.size()); + return Status::OK(); +} + +Status Memset32BitValueThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); + stream->ThenMemset32(&dest_data, value_, dest_data.size()); + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..b4bb74d1dd6dc9d09c5e4d439d57dfe8b57c2ed9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/stream_executor/stream_executor.h" + +// This file contains thunks that set a buffer's elements to a particular value. +// This can be faster than emitting a kernel to set the elements. + +namespace xla { +namespace gpu { + +// Thunk that zeroes out a given chunk of memory. +class MemzeroThunk : public Thunk { + public: + explicit MemzeroThunk(const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemzero, hlo), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const BufferAllocation::Slice dest_; +}; + +// Thunk that sets a given chunk of memory to a particular 32-bit value. The +// destination chunk must have size divisible by 32 bits. +class Memset32BitValueThunk : public Thunk { + public: + explicit Memset32BitValueThunk(uint32 value, + const BufferAllocation::Slice& dest, + const HloInstruction* hlo) + : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + uint32 value_; + const BufferAllocation::Slice dest_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 25846dc6cd4633c7becb6e62d6bc9585348a6eac..7bda4e2fcd469bd430e5ef1846251c8504225383 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -68,13 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput( HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - input = computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape( - /*operand_shape=*/input->shape(), - /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), - padding_config) - .ConsumeValueOrDie(), - input, padding, padding_config)); + input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } if (window_util::HasNegativePadding(conv_window)) { @@ -97,11 +92,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::max(0LL, -conv_window.dimensions(i).padding_high()); } - input = computation->AddInstruction(HloInstruction::CreateSlice( - ShapeInference::InferSliceShape(input->shape(), start_indices, - limit_indices, strides) - .ConsumeValueOrDie(), - input, start_indices, limit_indices, strides)); + input = + MakeSliceHlo(input, start_indices, limit_indices, strides).ValueOrDie(); } return input; @@ -134,13 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(element_type)))); - return computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape( - /*operand_shape=*/kernel->shape(), - /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}), - padding_config) - .ConsumeValueOrDie(), - kernel, padding, padding_config)); + return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -252,11 +238,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( computation->AddInstruction(HloInstruction::CreateConstant( MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = - computation->AddInstruction(HloInstruction::CreatePad( - ShapeInference::InferPadShape(input->shape(), padding->shape(), - input_padding_config) - .ConsumeValueOrDie(), - input, padding, input_padding_config)); + MakePadHlo(input, padding, input_padding_config).ValueOrDie(); // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 388dcc008b07a76ff9ed07df04181e49a8734f51..d8c07dc3119fb81a3ef22822acb11b7c4d5bbca5 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -32,25 +32,32 @@ namespace gpu { ParallelLoopEmitter::ParallelLoopEmitter( BodyEmitter body_emitter, const Shape& shape, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor) : LoopEmitter(body_emitter, shape, ir_builder), - launch_dimensions_(launch_dimensions) {} + launch_dimensions_(launch_dimensions), + unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor) : LoopEmitter(target_element_generator, target_arrays, ir_builder), - launch_dimensions_(launch_dimensions) {} + launch_dimensions_(launch_dimensions), + unroll_factor_(unroll_factor) {} ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder) + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor) : LoopEmitter(target_element_generator, target_array, ir_builder), - launch_dimensions_(launch_dimensions) {} + launch_dimensions_(launch_dimensions), + unroll_factor_(unroll_factor) {} -llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( +std::vector +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; @@ -63,6 +70,9 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x" // // %nctaid.x is currently specified as 2147483647. + VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; + std::vector array_indices; + llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), @@ -81,7 +91,7 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), "thread_id"); - llvm::Value* linear_index = ir_builder_->CreateAdd( + llvm::Value* linear_index_base = ir_builder_->CreateAdd( ir_builder_->CreateMul( block_id, ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", @@ -99,15 +109,30 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::assume, {ir_builder_->CreateICmpULT( - linear_index, + linear_index_base, ir_builder_->getInt64(launch_dimensions_.threads_per_block() * launch_dimensions_.block_count()), "linear_index_in_range")}, {}, ir_builder_); + if (unroll_factor_ > 1) { + linear_index_base = ir_builder_->CreateMul( + linear_index_base, ir_builder_->getInt64(unroll_factor_), + "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); + } + + array_indices.emplace_back(linear_index_base, shape_, ir_builder_); + for (int i = 1; i < unroll_factor_; ++i) { + llvm::Value* linear_index = ir_builder_->CreateAdd( + linear_index_base, ir_builder_->getInt64(i), "linear_index", + /*HasNUW=*/true, /*HasNSW=*/true); + array_indices.emplace_back(linear_index, shape_, ir_builder_); + } + auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( - linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), + linear_index_base, + ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); // Set exit_bb_ to the exit block of the if structure. @@ -116,7 +141,8 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // Set IR builder insertion point to the body of the if structure. llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - return llvm_ir::IrArray::Index(linear_index, shape_, ir_builder_); + + return array_indices; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 8ed63a854a74fc06c3c389f40fe1f5970885deac..25318b3bed8bf4a2dfe3a4a974269d0405c3bfec 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -34,13 +34,13 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // The meanings of other parameters are the same as LoopEmitter. ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); // Constructs a ParallelLoopEmitter from an element generator that generates // each element of the given target array. ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, const llvm_ir::IrArray& target_array, const LaunchDimensions& launch_dimensions, - llvm::IRBuilder<>* ir_builder); + llvm::IRBuilder<>* ir_builder, int unroll_factor = 1); // Constructs a loop emitter for a loop that generates on element of each of N // arrays on each iteration. @@ -50,18 +50,20 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder, + int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; - llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + std::vector EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) override; private: // The thread and block dimension to parallelize the loop on. const LaunchDimensions launch_dimensions_; + const int unroll_factor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 6cf280df05496716a0780d61ded92efd9982734c..5283d51cd10668c43c5ad1c1fb11049555bff5d4 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -44,12 +44,16 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc) { + const Shape& shape, const se::DeviceDescription& device_desc, + int unroll_factor) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } + CHECK_EQ(num_elements % unroll_factor, 0); + num_elements = num_elements / unroll_factor; + // Since we don't do any inter-warp communication, we're free to choose any // block size we want, subject to hardware constraints. We choose the // smallest block size that allows the GPU to reach full occupancy (assuming diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 0bf463a6ef95d5a32784838c08ad239752fd1acf..42d2d2af2e334da7c42419cb07a2bd5bb9d209d6 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -58,7 +58,8 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions( const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc); + const perftools::gputools::DeviceDescription& device_desc, + int unroll_factor = 1); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 2c3032d79be221e8cacb178ffb1817459b603cc0..9eea958d1214b131d49cb4e28f1944860408d3a8 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -51,6 +51,8 @@ class Thunk { kGemm, kInfeed, kKernel, + kMemset32BitValue, + kMemzero, kSequential, kTuple, kWhile, diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index cde5877e29f36abc61c5417ce960e2c7699e2749..3dd4c4a0794e5c41b877078c4e69c6c9584ce6c0 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -27,38 +27,6 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -namespace { - -// Returns the set of buffers that may be sources of all operands of the given -// instruction. The returned buffers are guaranteed to have no duplicates, and -// to be sorted in a deterministic order. -std::vector UniqueOperandSourceBuffers( - const HloInstruction* instruction, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector buffers; - for (const HloInstruction* operand : instruction->operands()) { - points_to_analysis.GetPointsToSet(operand).ForEachElement( - [&](const ShapeIndex& /*index*/, - const PointsToSet::BufferList& points_to) { - buffers.insert(buffers.end(), points_to.begin(), points_to.end()); - }); - } - - // Sort and then remove duplicates from buffers. - std::sort(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - buffers.erase(std::unique(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() == b->id(); - }), - buffers.end()); - return buffers; -} - -} // namespace - /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, @@ -93,6 +61,7 @@ Status HeapSimulator::RunComputation( const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { + VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -101,7 +70,51 @@ Status HeapSimulator::RunComputation( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. + // 'used_buffers' is the reverse map - it tracks which buffers were used by an + // instruction, so that we can remove the instructions from a buffer's live + // set after they are visited. FlatMap> live_buffers; + FlatMap> used_buffers; + auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( + const HloInstruction* user, + const LogicalBuffer* buffer) { + if (!IgnoreBuffer(buffer)) { + VLOG(4) << " Adding user " << user->name() << " to buffer " + << buffer->ToString(); + live_buffers[buffer].insert(user); + used_buffers[user].insert(buffer); + } + }; + + // Initialize live_buffers for each buffer that we're going to assign. The + // set of instructions that need to be visited contains all users of all + // aliases, that is, all users of all instructions that have the buffer + // contained in their points-to set. + for (const HloInstruction* instruction : instruction_sequence) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction); + const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); + for (const HloInstruction* user : instruction->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + for (const LogicalBuffer* buffer : buffer_set) { + add_user_to_buffer(user, buffer); + } + } else { + // A GetTupleElement doesn't need to keep all of its operand's buffers + // alive. It only needs the buffers that relate to the element its + // extracting, and the tuple it's extracting from, but not the buffers + // for the other elements. + for (const LogicalBuffer* buffer : points_to.element({})) { + add_user_to_buffer(user, buffer); + } + const PointsToSet& gte_points_to = + points_to_analysis.GetPointsToSet(user); + for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + add_user_to_buffer(user, buffer); + } + } + } + } const HloInstruction* root = computation.root_instruction(); auto output_source_buffers = @@ -114,34 +127,17 @@ Status HeapSimulator::RunComputation( buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); - // Initialize live_buffers for each buffer that we're going to assign. The - // set of instructions that need to be visited contains all users of all - // aliases. The alias itself is not necessary; if it has users, the users - // are necessarily scheduled after the alias. And if it has no users, it is - // either a dead value or an output, both of which are handled below. - // - // We ignore control dependencies here. The reasoning is that the control - // dependencies have already been accounted for in the ordering of the given - // 'instruction_sequence', and should not otherwise artificially extend the - // lifetime of buffers that aren't already connected by a data dependency. + VLOG(3) << "Instruction: " << instruction->ToString(); + for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + VLOG(4) << " Defines: " << buffer->ToString() + << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); + } + dead_buffers_to_free.clear(); for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } - FlatSet* live_set = nullptr; - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - const std::vector& users = - alias.instruction()->users(); - if (!users.empty()) { - if (live_set == nullptr) { - live_set = &live_buffers[buffer]; - } - live_set->insert(users.begin(), users.end()); - } - } - // Add a nullptr sentry to ensure entry parameters and output source // buffers are not freed until the very end. const bool entry_parameter = @@ -165,11 +161,12 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : - UniqueOperandSourceBuffers(instruction, points_to_analysis)) { + for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } + VLOG(4) << " Removing user " << instruction->name() << " from buffer " + << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); FlatSet* live_set = &it->second; live_set->erase(instruction); @@ -178,6 +175,11 @@ Status HeapSimulator::RunComputation( operand_buffers_to_free.push_back(operand_buffer); } } + // Sort to get a deterministic iteration order. + std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), + [](const LogicalBuffer* x, const LogicalBuffer* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. This must @@ -203,6 +205,8 @@ Status HeapSimulator::RunComputation( CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index(), points_to_analysis)) { + VLOG(3) << " Sharing: " << buffer->ToString() << " with " + << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; @@ -211,6 +215,7 @@ Status HeapSimulator::RunComputation( } if (!shared) { + VLOG(3) << " Allocating: " << buffer->ToString(); Alloc(buffer, instruction); } } @@ -225,6 +230,7 @@ Status HeapSimulator::RunComputation( // sub-computations will never be run concurrently. if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { @@ -243,20 +249,34 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. for (const LogicalBuffer* buffer : dead_buffers_to_free) { + VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } for (const LogicalBuffer* buffer : operand_buffers_to_free) { + VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } } // Any remaining live buffers must be entry parameters or output source - // buffers, which had a nullptr sentry added. Free them now. + // buffers, which had a nullptr sentry added. Free them now, in a + // deterministic order. + std::vector to_free; + to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const LogicalBuffer* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; + to_free.push_back(buffer); + } + + std::sort(to_free.begin(), to_free.end(), + [](const LogicalBuffer* x, const LogicalBuffer* y) { + return x->id() < y->id(); + }); + for (const LogicalBuffer* buffer : to_free) { + VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 387b649a731ebcbfd8307807469f39f22d192b06..688a271712ac243666ba4ff02932aa4f7f7ed21c 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -410,6 +410,56 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, IndependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32scalar_, "paramA")); + auto paramB = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32scalar_, "paramB")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kMultiply, paramA, paramB)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kAdd, paramA, paramB)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + auto element0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0)); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec4_, element0, {0})); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kSubtract, paramA, paramB)); + auto element1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1)); + auto output = builder.AddInstruction( + HloInstruction::CreateTuple({broadcast, sub, element1})); + + HeapSimulatorTracker tracker(TestName(), builder.Build(), + {paramA, paramB, mul, add, tuple, element0, + broadcast, sub, element1, output}); + tracker.ExpectCallSequence({ + {kAlloc, tracker.BufferAt(paramA, {})}, + {kAlloc, tracker.BufferAt(paramB, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(add, {})}, + {kAlloc, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(broadcast, {})}, + // The mul can be freed right after the broadcast happens, even though + // The other GetTupleElement is still alive. + {kFree, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(sub, {})}, + // The temporary tuple is now dead. + {kFree, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(output, {})}, + // All params and outputs are freed at the end. + {kFree, tracker.BufferAt(paramA, {})}, + {kFree, tracker.BufferAt(paramB, {})}, + {kFree, tracker.BufferAt(add, {})}, + {kFree, tracker.BufferAt(broadcast, {})}, + {kFree, tracker.BufferAt(sub, {})}, + {kFree, tracker.BufferAt(output, {})}, + {kFinish, nullptr}, + }); +} + TEST_F(HeapSimulatorTest, WholeModule) { HeapSimulatorTracker tracker(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 36db711c6c3570efdf678261ad38bbdb08cf94aa..8fd7f8945c7c36a451af30fcd5939a2498648e74 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// DO NOT USE THESE PROTO MESSAGES FOR ANYTHING OTHER THAN DEBUGGING. -// -// Don't use these protos in the real compilation or execution codepaths. The -// data format is meant for debugging only, and may change without notice. +// This proto file defines messages which represent the HLO module. This is a +// full fidelity serialization of the c++ HLO constructs. // // Many of the protos below are simple 1-to-1 serializations of the -// corresponding C++ classes. +// corresponding C++ classes, e.g., HloModule, HloComputation, and +// HloInstruction. // // FIELD NAMES ARE IMPORTANT // @@ -38,16 +37,19 @@ option cc_enable_arenas = true; message HloInstructionProto { reserved 10; reserved "parameter_name"; + reserved 12; + reserved "fused_instructions_computation"; + reserved 4; + reserved "operand_names"; + reserved 5; + reserved "control_predecessor_names"; + reserved 6; + reserved "called_computation_names"; string name = 1; string opcode = 2; xla.Shape shape = 3; - // TODO(b/67782397): Replace instruction names with HloInstruction ids. - repeated string operand_names = 4; - repeated string control_predecessor_names = 5; - repeated string called_computation_names = 6; - xla.OpMetadata metadata = 7; // Literal, only present for kConstant. @@ -58,7 +60,6 @@ message HloInstructionProto { // Fusion state, only present for kFusion. string fusion_kind = 11; - HloComputationProto fused_instructions_computation = 12; // Index for kGetTupleElement. int64 tuple_index = 13; @@ -129,28 +130,61 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + + // Gather dimension numbers. + xla.GatherDimensionNumbers gather_dimension_numbers = 33; + repeated int64 gather_window_bounds = 34; + + // Compute Host. + string channel_name = 41; + int64 cost_estimate_ns = 42; + + // The id of this instruction. + int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 control_predecessor_ids = 37; + repeated int64 called_computation_ids = 38; + + xla.OpSharding sharding = 40; } // Serialization of HloComputation. message HloComputationProto { + reserved 3; + reserved "root_name"; + string name = 1; // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; - // The name of the root of the computation. - string root_name = 3; + // The program shape (with layout) of this computation. + xla.ProgramShape program_shape = 4; + + // The id of this computation. + int64 id = 5; + + // The id of the root of the computation. + int64 root_id = 6; } // Serialization of HloModule. message HloModuleProto { string name = 1; string entry_computation_name = 2; + int64 entry_computation_id = 6; // The array of computations is always in a valid dependency order, where // callees appear before their callers. repeated HloComputationProto computations = 3; + + // The program shape (with layout) of the entry computation. + xla.ProgramShape program_shape = 4; + + // The id of this module. + int64 id = 5; } // Serialization of HloOrdering. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6d2a3aa5b531650a658502531e050702ffbd3760..a88283ed9a6459b4fa9310e160b59c77d51f1027 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -171,24 +171,21 @@ class BufferValueMap { return value_to_buffer_number_.at(&value); } - // Compute and return a vector of buffers that the given value must be - // contained in due to HLO aliasing rules. - std::vector ComputeAliasedBuffers(const HloValue& value) { + void ComputeWhileAliasedBuffers(const HloValue& value, + std::vector* aliased_buffers) { + VLOG(3) << "Compute kWhile aliases"; // Value is init of a while (use is while). - std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = dataflow_.GetUniqueValueAt(use.instruction, use.operand_index); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); VLOG(3) << " value is init value to a while; must share buffer with " "while value " << while_value.ToShortString(); } } - // Value is a parameter of a while body/condition. if (value.defining_instruction()->opcode() == HloOpcode::kParameter) { const HloComputation* computation = @@ -205,11 +202,10 @@ class BufferValueMap { VLOG(3) << " value is parameter value of the body or condition of a " "while; must share buffer with while value " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } - // Value is the root of a while body. for (const HloPosition& position : value.positions()) { const HloComputation* computation = position.instruction->parent(); @@ -224,27 +220,71 @@ class BufferValueMap { const HloValue& while_value = dataflow_.GetUniqueValueAt( callsite.instruction(), position.index); - VLOG(3) << " value is root the body computation of a while; must " - "share buffer with while value " + VLOG(3) << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; body root and while value root must share buffer " + "among them : " << while_value.ToShortString(); - aliased_buffers.push_back(GetBufferForValue(while_value)); + aliased_buffers->push_back(GetBufferForValue(while_value)); } } } } - // Value is the output of the while instruction itself. if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { VLOG(3) << " value is output of a while instruction"; - aliased_buffers.push_back(GetBufferForValue(value)); + aliased_buffers->push_back(GetBufferForValue(value)); + } + } + + void ComputeConditionalAliasedBuffers( + const HloValue& value, std::vector* aliased_buffers) { + VLOG(3) << "Compute kConditional aliases"; + // Aliases the buffers of the true/false computations roots, with the one of + // the conditional. + for (const HloPosition& position : value.positions()) { + const HloComputation* computation = position.instruction->parent(); + const CallGraphNode& call_graph_node = + dataflow_.call_graph().GetNode(computation); + if (position.instruction == computation->root_instruction()) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + // Call graph must have been flattened. + CHECK_EQ(call_graph_node.caller_callsites().size(), 1); + + const HloValue& cond_value = dataflow_.GetUniqueValueAt( + callsite.instruction(), position.index); + VLOG(3) + << " value @ " << position << " is root of " + << callsite.instruction()->name() + << "; true/false branch roots must share buffer among them : " + << cond_value.ToShortString(); + aliased_buffers->push_back(GetBufferForValue(cond_value)); + } + } + } + } + // Value is the output of the conditional instruction itself. + if (value.defining_instruction()->opcode() == HloOpcode::kConditional) { + VLOG(3) << " value is output of a conditional instruction"; + aliased_buffers->push_back(GetBufferForValue(value)); } + } + // Compute and return a vector of buffers that the given value must be + // contained in due to HLO aliasing rules. + std::vector ComputeAliasedBuffers(const HloValue& value) { + for (const HloUse& use : value.uses()) { + VLOG(2) << "Use of value " << value.ToShortString() << ": " << use; + } + std::vector aliased_buffers; + ComputeWhileAliasedBuffers(value, &aliased_buffers); + ComputeConditionalAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers. std::sort(aliased_buffers.begin(), aliased_buffers.end()); aliased_buffers.erase( std::unique(aliased_buffers.begin(), aliased_buffers.end()), aliased_buffers.end()); - return aliased_buffers; } @@ -419,7 +459,7 @@ StatusOr> HloAliasAnalysis::Run( auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN( alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(module, /*ssa_form=*/true, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 5432419e4a2dd2916da32ac6566851bf52fd68ca..594413e88fb26e86b198d08b2e4db77fad671348 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -65,6 +65,7 @@ HloComputation::HloComputation( std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), + unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); @@ -101,7 +102,7 @@ HloInstruction* HloComputation::AddInstructionInternal( instruction->UniquifyName(&parent()->instruction_name_uniquer()); instruction->SetUniqueId(parent()->NewUniqueInstructionId()); } - Reparent(instruction.get()); + instruction->set_parent(this); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = instructions_.insert(instructions_.end(), std::move(instruction)); @@ -158,10 +159,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { return Status::OK(); } -void HloComputation::Reparent(HloInstruction* instruction) { - instruction->set_parent(this); -} - bool HloComputation::IsRemovable(const HloInstruction* instruction) { // If the instruction has control predecessors or successors then we cannot // remove the instruction without violating ordering constraints (added, for @@ -307,19 +304,15 @@ void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, std::list* post_order) { - if (visited->count(computation) > 0) { - return; - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* called_computation : - instruction->called_computations()) { - ComputeComputationPostOrder(called_computation, visited, post_order); + if (visited->insert(computation).second) { + for (auto* instruction : computation->instructions()) { + for (HloComputation* called_computation : + instruction->called_computations()) { + ComputeComputationPostOrder(called_computation, visited, post_order); + } } + post_order->push_back(computation); } - - visited->insert(computation); - post_order->push_back(computation); } } // namespace @@ -393,43 +386,46 @@ string HloComputation::ToString(const HloPrintOptions& options) const { HloComputationProto HloComputation::ToProto() const { HloComputationProto proto; + CHECK(unique_id_ != -1) + << "This computation does not have a valid id. Please make sure the " + "computation is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); for (const HloInstruction* instruction : MakeInstructionPostOrder()) { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } - proto.set_root_name(root_instruction()->name()); + proto.set_root_id(root_instruction()->unique_id()); + *proto.mutable_program_shape() = ComputeProgramShape(); return proto; } /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction) { + const tensorflow::gtl::FlatMap& computation_map) { std::vector> instructions; - tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, - HloInstruction::CreateFromProto( - module, instruction_proto, instruction_map, - computation_map, add_fused_computation)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } - TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); - instruction_map[instruction->name()] = instruction.get(); + TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); + instruction_map[instruction_proto.id()] = instruction.get(); instructions.push_back(std::move(instruction)); } - TF_RET_CHECK(!proto.root_name().empty()); - TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); - HloInstruction* root = instruction_map.at(proto.root_name()); - return WrapUnique(new HloComputation( - proto.name(), parameter_count, &instructions, root, fusion_instruction)); + TF_RET_CHECK(proto.root_id() != -1); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); + HloInstruction* root = instruction_map.at(proto.root_id()); + return WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( @@ -509,13 +505,14 @@ StatusOr HloComputation::DeepCopyInstruction( "Can't deep copy instruction %s: instruction is not in computation %s", instruction->name().c_str(), name().c_str()); } - if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " - "has incompatible shape", - instruction->name().c_str()); + "has incompatible shapes: %s vs. %s", + instruction->name().c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); } ShapeIndex index; @@ -531,7 +528,6 @@ ProgramShape HloComputation::ComputeProgramShape() const { } *program_shape.mutable_result() = root_instruction_->shape(); - LayoutUtil::ClearLayout(&program_shape); return program_shape; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 061c59abe5e315917161ed737f89de53d71bb1b6..9d3f6e9a2c2efd97681a22b6b0f6d929afc553de 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -77,6 +77,14 @@ class HloComputation { return last_added_instruction_; } + Status ForEachInstruction( + const std::function& func) const { + for (const auto& instruction : instructions_) { + TF_RETURN_IF_ERROR(func(instruction.get())); + } + return Status::OK(); + } + private: const string name_; HloInstruction* last_added_instruction_; @@ -152,20 +160,12 @@ class HloComputation { // module: the module which will contain the computation. The newly created // computation is *not* added to the module, however. // proto: the proto to convert from. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used only when the instruction is a fusion instruction. - // fusion_instruction: if non-null then the newly created computation will - // be constructed as a fused computation with this instruction as its - // fusion parent. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation, - HloInstruction* fusion_instruction = nullptr); + const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. // @@ -240,7 +240,7 @@ class HloComputation { ShapeTree* copies_added = nullptr); // Computes and returns the ProgramShape of this computation (shape of - // parameters and result without layout). + // parameters and result with layout). ProgramShape ComputeProgramShape() const; // Return whether `*this` and `other` are functionally equivalent. @@ -334,6 +334,15 @@ class HloComputation { fusion_instruction_ = fusion_instruction; } + // The id of this computation should be unique within the module. + void SetUniqueId(int64 id) { + CHECK_EQ(unique_id_, -1); + CHECK_GE(id, 0); + unique_id_ = id; + } + + int64 unique_id() const { return unique_id_; } + private: explicit HloComputation( const string& name, int parameter_count, @@ -344,10 +353,6 @@ class HloComputation { HloInstruction* AddInstructionInternal( std::unique_ptr instruction); - // Helper for setting the parent of instructions that are added to this - // computation. - void Reparent(HloInstruction* instruction); - // Fuses HLOs in instructions_to_fuse into fusion_instruction. // // Pre-condition: fusion_instruction's opcode is kFusion. @@ -365,6 +370,7 @@ class HloComputation { std::vector CollectUnreachableRoots() const; string name_; + int64 unique_id_; HloInstruction* root_instruction_; // If this computation is a fusion computation, this field points to the diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 53450991b6fad5b9651d9d23b55c908e6b68e5dd..7aa38c6b79ed904bb4a518c4b7aaa1e079c27ea8 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -35,7 +35,10 @@ limitations under the License. namespace xla { StatusOr HloConstantFolding::Run(HloModule* module) { - auto evaluator = MakeUnique(); + // Limit the constant folding to 0 iterations to skip folding loops. This + // retains the behavior from before while loop support in HloEvaluator and may + // be revised. + auto evaluator = MakeUnique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); @@ -66,7 +69,8 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. - if (instruction->opcode() == HloOpcode::kBroadcast) { + if (instruction->opcode() == HloOpcode::kBroadcast || + instruction->opcode() == HloOpcode::kBroadcastDimOne) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 9cd5a1e2b71a7aa768e478289e8e4cc13030fcc3..ea4dd62fdb5bb3be40987d1a6ea96b3a58b0053b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -229,6 +230,10 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, @@ -331,6 +336,11 @@ Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleBroadcastDimOne( + const HloInstruction* broadcastDimOne) { + return Status::OK(); +} + Status HloCostAnalysis::HandlePad(const HloInstruction*) { return Status::OK(); } @@ -375,20 +385,101 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { } Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { - auto rhs_instruction = convolution->operand(1); + auto lhs = convolution->operand(0); + auto rhs = convolution->operand(1); + Window window = convolution->window(); + const auto& result_shape = convolution->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const auto& dnums = convolution->convolution_dimension_numbers(); - const int64 output_features = - convolution->shape().dimensions(dnums.output_feature_dimension()); - // For each output element, we do one fma per element in the kernel at some - // given output feature index. - const int64 fmas_per_output_element = - output_features > 0 - ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features - : 0; - const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); - current_properties_[kFlopsKey] = - output_elements * fmas_per_output_element * kFmaFlops; + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_feature_dim = dnums.input_feature_dimension(); + const int64 output_feature_dim = dnums.output_feature_dimension(); + const int64 input_feature = + ShapeUtil::GetDimension(lhs_shape, input_feature_dim); + const int64 output_feature = + ShapeUtil::GetDimension(result_shape, output_feature_dim); + const int64 batch = ShapeUtil::GetDimension(lhs_shape, input_batch_dim); + + DimensionVector kernel_limits; + DimensionVector output_limits; + DimensionVector input_limits; + if (window.dimensions().empty()) { + window = window_util::MakeWindow({1}); + kernel_limits.push_back(1); + output_limits.push_back(1); + input_limits.push_back(1); + } else { + for (int64 spatial_dimension = 0; + spatial_dimension < window.dimensions_size(); ++spatial_dimension) { + // Spatial dimension number for kernel (rhs). + const int64 kernel_spatial_dim = + dnums.kernel_spatial_dimensions(spatial_dimension); + const int64 kernel_limit = rhs_shape.dimensions(kernel_spatial_dim); + kernel_limits.push_back(kernel_limit); + + // Spatial dimension number for output. + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(spatial_dimension); + const int64 output_limit = result_shape.dimensions(output_spatial_dim); + output_limits.push_back(output_limit); + + // Spatial dimension number for input (lhs). + const int64 input_spatial_dim = + dnums.input_spatial_dimensions(spatial_dimension); + const int64 input_limit = lhs_shape.dimensions(input_spatial_dim); + input_limits.push_back(input_limit); + } + } + + DimensionVector valid_position_counts; + + // Loop over each spatial dimension. + for (int64 spatial_dimension = 0; + spatial_dimension < window.dimensions_size(); ++spatial_dimension) { + int64 valid_position_count = 0; + // Loop over each point in the kernel. + for (int64 kernel_idx = 0; kernel_idx < kernel_limits[spatial_dimension]; + ++kernel_idx) { + // Loop over each point in the output. + for (int64 output_idx = 0; output_idx < output_limits[spatial_dimension]; + ++output_idx) { + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(spatial_dimension); + const int64 undilated_index = output_idx * window_dim.stride() - + window_dim.padding_low() + + kernel_idx * window_dim.window_dilation(); + + // Calculate the actual lhs (input) index after dilation. Avoid the + // division as an optimization. + const int64 lhs_spatial_index = + window_dim.base_dilation() > 1 + ? undilated_index / window_dim.base_dilation() + : undilated_index; + + // Skip if the lhs (input) index is to be dilated. + if (undilated_index != lhs_spatial_index * window_dim.base_dilation()) { + continue; + } + + // Skip if input index is not in bound. + if (lhs_spatial_index < 0 || + lhs_spatial_index >= input_limits[spatial_dimension]) { + continue; + } + + valid_position_count += 1; + } + } + valid_position_counts.push_back(valid_position_count); + } + + const int64 fma_count = + input_feature * output_feature * batch * Product(valid_position_counts); + current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } @@ -529,6 +620,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { return Status::OK(); } +Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { + // Gather does not issue any flops. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index e5783539e5436f09fa58bf7889118380ee90fea0..a9f6845747aa2081df936d388551bbc0b75b787b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; @@ -94,11 +95,13 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleSelectAndScatter(const HloInstruction* instruction) override; Status HandleBitcast(const HloInstruction* bitcast) override; Status HandleBroadcast(const HloInstruction* broadcast) override; + Status HandleBroadcastDimOne(const HloInstruction* broadcastDimOne) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; + Status HandleGather(const HloInstruction* gather) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 3b289c240a45e8f3df8156ed89e879da2132d01a..3d055b327ee920dac9c0904c69e1461206b31203 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -186,12 +186,14 @@ TEST_F(HloCostAnalysisTest, Map) { TEST_F(HloCostAnalysisTest, Convolution) { ComputationBuilder builder(client_, "convolution"); auto input = builder.Parameter( - 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, - /*x_dim=*/20}), + 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), "input"); auto kernel = builder.Parameter( - 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, - /*x_dim=*/3}), + 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), "kernel"); auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid); @@ -440,5 +442,32 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { + ComputationBuilder builder(client_, "BaseDilatedConvolution"); + auto input = builder.Parameter( + 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = builder.Parameter( + 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + + auto result = builder.ConvGeneralDilated( + input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, + /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, + ComputationBuilder::CreateDefaultConvDimensionNumbers(2)); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(), 1472); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..b186767ce792cd89ae77fe9a03b3a2ecf296b804 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -0,0 +1,277 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +using tensorflow::gtl::ArraySlice; +using tensorflow::strings::StrCat; + +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN(Shape binary_op_shape, + ShapeInference::InferBinaryOpShape(opcode, lhs, rhs)); + return computation->AddInstruction( + HloInstruction::CreateBinary(binary_op_shape, opcode, lhs, rhs)); +} + +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, padding_value->parent()); + TF_ASSIGN_OR_RETURN( + Shape pad_shape, + ShapeInference::InferPadShape(operand->shape(), padding_value->shape(), + padding_config)); + return computation->AddInstruction(HloInstruction::CreatePad( + pad_shape, operand, padding_value, padding_config)); +} + +StatusOr MakeSliceHlo(HloInstruction* operand, + ArraySlice start_indices, + ArraySlice limit_indices, + ArraySlice strides) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( + operand->shape(), start_indices, + limit_indices, strides)); + return computation->AddInstruction(HloInstruction::CreateSlice( + slice_shape, operand, start_indices, limit_indices, strides)); +} + +StatusOr MakeConvolveHlo( + HloInstruction* lhs, HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), + window, dimension_numbers)); + return computation->AddInstruction(HloInstruction::CreateConvolve( + convolve_shape, lhs, rhs, window, dimension_numbers)); +} + +StatusOr MakeTransposeHlo(HloInstruction* operand, + ArraySlice dimensions) { + HloComputation* computation = operand->parent(); + TF_ASSIGN_OR_RETURN( + Shape transpose_shape, + ShapeInference::InferTransposeShape(operand->shape(), dimensions)); + return computation->AddInstruction( + HloInstruction::CreateTranspose(transpose_shape, operand, dimensions)); +} + +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand) { + HloComputation* computation = operand->parent(); + return computation->AddInstruction( + HloInstruction::CreateReshape(result_shape, operand)); +} + +StatusOr MakeReshapeHlo( + ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + result_shape_dim_bounds); + return MakeReshapeHlo(new_shape, operand); +} + +StatusOr MakeDynamicSliceHlo(HloInstruction* operand, + HloInstruction* start_indices, + ArraySlice slice_sizes) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, start_indices->parent()); + TF_ASSIGN_OR_RETURN( + Shape dynamic_slice_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), slice_sizes)); + return computation->AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, operand, start_indices, slice_sizes)); +} + +StatusOr MakeDynamicUpdateSliceHlo( + HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices) { + HloComputation* computation = operand->parent(); + CHECK_EQ(computation, update->parent()); + CHECK_EQ(computation, start_indices->parent()); + TF_ASSIGN_OR_RETURN( + Shape dynamic_update_slice_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + return computation->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + dynamic_update_slice_shape, operand, update, start_indices)); +} + +StatusOr MakeBroadcastHlo( + HloInstruction* operand, ArraySlice broadcast_dimensions, + ArraySlice result_shape_bounds) { + HloComputation* computation = operand->parent(); + Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + result_shape_bounds); + + return computation->AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, operand, broadcast_dimensions)); +} + +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index) { + HloComputation* computation = operand->parent(); + + TF_ASSIGN_OR_RETURN( + Shape gte_shape, + ShapeInference::InferGetTupleElementShape(operand->shape(), index)); + return computation->AddInstruction( + HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); +} + +StatusOr MakeConcatHlo(ArraySlice operands, + int64 dimension) { + CHECK_GT(operands.size(), 0); + + HloComputation* computation = operands[0]->parent(); + CHECK(c_all_of(operands, [&](HloInstruction* instr) { + return instr->parent() == computation; + })); + + std::vector operand_shapes; + c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); + + TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( + operand_shapes, dimension)); + return computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); +} + +StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { + const Shape& operand_shape = operand->shape(); + CHECK_GE(operand_shape.dimensions_size(), n); + int64 new_shape_leading_bound = 1; + for (int64 i = 0; i < n; i++) { + new_shape_leading_bound *= operand_shape.dimensions(i); + } + + std::vector new_shape_dims; + new_shape_dims.reserve(operand_shape.dimensions_size() - n + 1); + new_shape_dims.push_back(new_shape_leading_bound); + + std::copy(operand_shape.dimensions().begin() + n, + operand_shape.dimensions().end(), + std::back_inserter(new_shape_dims)); + + Shape output_shape = + ShapeUtil::MakeShape(operand_shape.element_type(), new_shape_dims); + + return MakeReshapeHlo(output_shape, operand); +} + +StatusOr ExpandFirstDimIntoNDims( + HloInstruction* operand, ArraySlice expanded_dims) { + CHECK_GT(operand->shape().dimensions_size(), 0); + CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); + + std::vector expanded_shape_dim_bounds; + expanded_shape_dim_bounds.reserve(expanded_dims.size() + + operand->shape().dimensions_size() - 1); + c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + std::copy(operand->shape().dimensions().begin() + 1, + operand->shape().dimensions().end(), + std::back_inserter(expanded_shape_dim_bounds)); + Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + expanded_shape_dim_bounds); + return MakeReshapeHlo(new_shape, operand); +} + +StatusOr ElideDegenerateDims(HloInstruction* operand, + ArraySlice dims_to_elide) { + CHECK(c_is_sorted(dims_to_elide)); + + const Shape& input_shape = operand->shape(); + // First accumulate in reverse + std::vector new_shape_dim_bounds; + new_shape_dim_bounds.reserve(input_shape.dimensions_size() - + dims_to_elide.size()); + int64 dims_to_elide_idx = dims_to_elide.size() - 1; + for (int64 i = input_shape.dimensions_size() - 1; i >= 0; i--) { + if (dims_to_elide_idx >= 0 && i == dims_to_elide[dims_to_elide_idx]) { + CHECK_EQ(input_shape.dimensions(i), 1); + dims_to_elide_idx--; + } else { + new_shape_dim_bounds.push_back(input_shape.dimensions(i)); + } + } + + c_reverse(new_shape_dim_bounds); + Shape output_shape = + ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + +StatusOr PadVectorWithZeros(HloInstruction* operand, + int64 zeros_to_prepend, + int64 zeros_to_append) { + HloComputation* computation = operand->parent(); + CHECK_EQ(operand->shape().dimensions_size(), 1); + PaddingConfig padding_config; + PaddingConfig::PaddingConfigDimension padding_config_dim; + padding_config_dim.set_edge_padding_low(zeros_to_prepend); + padding_config_dim.set_edge_padding_high(zeros_to_append); + *padding_config.add_dimensions() = padding_config_dim; + + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(operand->shape().element_type())))); + return MakePadHlo(operand, zero, padding_config); +} + +StatusOr BroadcastZeros( + HloComputation* computation, PrimitiveType element_type, + ArraySlice broadcast_dimensions) { + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(element_type)))); + return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/broadcast_dimensions); +} + +StatusOr> CreateComputationWithSignature( + ArraySlice domain, const Shape& range, + tensorflow::StringPiece name) { + HloComputation::Builder b(name.ToString()); + int64 param_idx = 0; + for (const Shape* param_shape : domain) { + b.AddInstruction(HloInstruction::CreateParameter( + param_idx, *param_shape, StrCat("param.", param_idx))); + param_idx++; + } + + // We can't change the root type of a computation once it is created so create + // a dummy root instruction to give the computation the right root shape. In + // the future we may want to use a (recursive) broadcast here to avoid + // creating large constants. + b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(range))); + + return b.Build(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d99e32a737e6aaa2ff746cf6c00d4300cf62f4e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Some lightweight utilities intended to make HLO instruction creation more +// ergonomic. We don't have a complete set of helpers yet -- I expect we'll +// expand this interface as needed on an ad-hoc basis. + +// Creates a binary HLO instruction and adds it to the computation containing +// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, + HloInstruction* rhs); + +// Creates a pad HLO instruction and adds it to the computation containing +// `operand` and `padding_value` (`operand` and `padding_value` must be in the +// same computation). +StatusOr MakePadHlo(HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + +// Creates a slice HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeSliceHlo( + HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + +// Creates a convolution HLO instruction and adds it to the computation +// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). +StatusOr MakeConvolveHlo( + HloInstruction* lhs, HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + +// Creates a transpose HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeTransposeHlo( + HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + +// Creates a reshape HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeReshapeHlo(const Shape& result_shape, + HloInstruction* operand); + +StatusOr MakeReshapeHlo( + tensorflow::gtl::ArraySlice result_shape_dim_bounds, + HloInstruction* operand); + +// Creates a dynamic-slice HLO instruction and adds it to the computation +// containing `operand` and `start_indices` (`operand` and `start_indices` must +// be in the same computation). +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + +// Creates a dynamic-update-slice HLO instruction and adds it to the computation +// containing `operand`, `update` and `start_indices` (`operand`, `update` and +// `start_indices` must be in the same computation). +StatusOr MakeDynamicUpdateSliceHlo( + HloInstruction* operand, HloInstruction* update, + HloInstruction* start_indices); + +// Creates a broadcast HLO instruction and adds it to the computation containing +// `operand`. +StatusOr MakeBroadcastHlo( + HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimensions, + tensorflow::gtl::ArraySlice result_shape_bounds); + +// Creates a GetTupleElement HLO instruction and adds it to the computation +// containing `operand`. +StatusOr MakeGetTupleElementHlo(HloInstruction* operand, + int64 index); + +// Creates a Concatenate HLO instruction and adds it to the computation +// containing `operands` (`operands` must be non-empty and every element must be +// contained in the same computation). +StatusOr MakeConcatHlo( + tensorflow::gtl::ArraySlice operands, int64 dimension); + +// ----------------------------------------------------------------------------- +// Some other miscellaneous helpers to generate common HLO patterns. All of +// these add all the instructions they generate into the computation containing +// their operand(s). + +// Collapses (via reshape) the first N (logical) dimensions of `operand` into a +// single leading dimension. `operand` must have rank > n. +// +// For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is +// the `operand` reshaped to [56,9]. +StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n); + +// Expands (via reshape) the first (logical) dimension of `operand` into a +// sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 +// and the number of elements in its first dimension must be equal to the +// product of `expanded_dims`. +// +// For instance if `operand` has shape f32[200,9,7] and expanded_dims is +// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. +StatusOr ExpandFirstDimIntoNDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + +// Elides (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_elide` from `operand`. Every dimension in +// `dims_to_elide` must be a degenerate dimension. `dims_to_elide` must be +// sorted and not contain duplicates. +// +// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide +// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. +StatusOr ElideDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + +// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the +// front and `zeros_to_append` zeros in the back. +StatusOr PadVectorWithZeros(HloInstruction* operand, + int64 zeros_to_prepend, + int64 zeros_to_append); + +// Broadcasts a zero value of type `element_type` into a tensor with element +// type `element_type` and dimension bounds `broadcast_dimensions`. The +// broadcast instruction is emitted into `computation`. +StatusOr BroadcastZeros( + HloComputation* computation, PrimitiveType element_type, + tensorflow::gtl::ArraySlice broadcast_dimensions); + +// Creates a HLO computation that takes arguments of type `domain` and produces +// a value of type `range`. +StatusOr> CreateComputationWithSignature( + tensorflow::gtl::ArraySlice domain, const Shape& range, + tensorflow::StringPiece name); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 279edd4ba8772a9c576f76f554de8ec68631b953..cd7cbbdd71706fddb64855f631eb09de35da52e8 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -109,6 +109,11 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } + // Skip instructions which have side effects. + if (instruction->HasSideEffect()) { + continue; + } + // An instruction is considered to be equivalent to another only if they // share the exact same set of operands. So to find equivalent // instructions, we just search among instructions which share operand(0) @@ -118,7 +123,7 @@ StatusOr HloCSE::Run(HloModule* module) { tensorflow::gtl::InlinedVector equivalent_instructions; for (HloInstruction* user : operand->users()) { - if (user != instruction && + if (user != instruction && !user->HasSideEffect() && user->Identical(*instruction, eq_instructions, eq_computations, is_layout_sensitive_)) { equivalent_instructions.push_back(user); diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 3601a790c4428ee39c264b217a4b9a991ad8456c..df8853f34f6a72c52d1cde7332ada3809d2f3d96 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -414,8 +414,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { EXPECT_THAT(root, op::Add(rng1, rng2)); } -// TODO(b/28245743): Handle impure functions correctly in CSE. -TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { +TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { // Test that two calls to an impure function are not commoned. RNG // is the source of the impurity. @@ -458,14 +457,16 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Add(op::Map(), op::Map())); + VLOG(3) << "before: " << module->ToString(); + HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + VLOG(3) << "after: " << module->ToString(); EXPECT_EQ(4, computation->instruction_count()); root = computation->root_instruction(); - auto operand = root->operand(0)->operand(0); - EXPECT_THAT(operand, op::Map()); - EXPECT_THAT(root, op::Add(operand, operand)); + EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index ccbbe8f1966d59b4ab2904dcc6ea724aaf4a7603..0c37a8d75f38dabaad886cc9d4adce8ab29ddf18 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -38,12 +38,12 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, +HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(module)) {} + call_graph_(CallGraph::Build(&module)) {} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() { } string HloDataflowAnalysis::ToString() const { - string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); + string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n"); StrAppend(&out, " Instruction value sets:\n"); - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { @@ -368,11 +368,11 @@ bool HloDataflowAnalysis::UpdateConditionalValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( conditional->false_computation()->root_instruction())}; - // A phi-node is not defined for a kConditional instruction even though it - // represents a join point. This is because the current approach is to define - // a phi-node only for kWhile to account for the dataflow through back-edges - // and deal with the ambiguity in other cases. - return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + if (ssa_form_) { + return Phi(conditional, inputs); + } else { + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); + } } bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { @@ -592,7 +592,7 @@ void HloDataflowAnalysis::Propagate() { } }; - for (HloComputation* computation : module_->computations()) { + for (HloComputation* computation : module_.computations()) { for (HloInstruction* instruction : computation->instructions()) { add_to_worklist(instruction); } @@ -686,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( } Status HloDataflowAnalysis::InitializeInstructionValueSets() { - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); for (HloInstruction* instruction : computation->instructions()) { // Create an empty shape tree. @@ -787,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - HloModule* module, bool ssa_form, bool bitcast_defines_value) { - VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name(); - XLA_VLOG_LINES(2, module->ToString()); + const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); auto dataflow_analysis = WrapUnique( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); @@ -806,7 +806,7 @@ StatusOr> HloDataflowAnalysis::Run( // lookup is faster. std::vector> value_positions( dataflow_analysis->next_value_id_); - for (const HloComputation* computation : module->computations()) { + for (const HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : dataflow_analysis->GetInstructionValueSet(instruction)) { @@ -858,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const { // For each value in each value set, verify that the value set's position // appears in the value's positions(). - for (const auto& computation : module_->computations()) { + for (const auto& computation : module_.computations()) { for (const auto& instruction : computation->instructions()) { for (const auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 89d318188f0855c7924836a51cfe98d531e08cb4..7b8a74b096ff48733717e78ada5bb56a28caed72 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -60,7 +60,7 @@ class HloDataflowAnalysis { // a new HLO value in the analysis. If false then Bitcast forwards the // value of its operand. static StatusOr> Run( - HloModule* module, bool ssa_form = false, + const HloModule& module, bool ssa_form = false, bool bitcast_defines_value = false); // Returns true if 'instruction' defines an HLO value at the given shape index @@ -119,7 +119,7 @@ class HloDataflowAnalysis { string ToString() const; protected: - HloDataflowAnalysis(HloModule* module, bool ssa_form, + HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); // Returns a new HloValue defined at the given instruction and shape index. @@ -180,7 +180,7 @@ class HloDataflowAnalysis { // Verify various invariants of the dataflow analysis. Status Verify() const; - HloModule* const module_; + const HloModule& module_; const bool ssa_form_; const bool bitcast_defines_value_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e714b2567fd1b3eab607a19f0bb7e3288150dc64..07f69b8e1339fed636e4eb54791941b85e09fd17 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase, bool bitcast_defines_value = false) { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); analysis_ = - HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } @@ -1602,11 +1602,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), ElementsAre(HloUse{conditional, 2, {}})); - EXPECT_EQ(analysis.values().size(), 3); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), - analysis.GetValueDefinedAt(constant2))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 4); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); + } } TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { @@ -1713,11 +1719,17 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); - EXPECT_EQ(analysis.values().size(), 6); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT(HloValuesAt(conditional), - UnorderedElementsAre(analysis.GetValueDefinedAt(add), - analysis.GetValueDefinedAt(sub))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 7); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); + } } TEST_P(HloDataflowAnalysisTest, NestedConditionals) { @@ -1834,20 +1846,27 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) { EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), analysis.GetValueDefinedAt(constant2)); - EXPECT_EQ(analysis.values().size(), 9); - EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); - EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); - EXPECT_THAT( - HloValuesAt(inner_conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()))); - EXPECT_THAT( - HloValuesAt(conditional), - UnorderedElementsAre( - analysis.GetValueDefinedAt(computation1->root_instruction()), - analysis.GetValueDefinedAt(computation2->root_instruction()), - analysis.GetValueDefinedAt(computation3->root_instruction()))); + bool ssa_form = GetParam(); + if (ssa_form) { + EXPECT_EQ(analysis.values().size(), 11); + EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional)); + } else { + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); + } } INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 1e5f0f797a13fd7e7ce1cc934387a274a74153bc..fcd723af146e2227b8661b1a4993f1338f7de389 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -40,7 +40,7 @@ StatusOr HloDCE::Run(HloModule* module) { VLOG(2) << "Before dce:"; XLA_VLOG_LINES(2, module->ToString()); - for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* computation : module->MakeComputationPostOrder()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 8016b38d15330842c0e11f192587b6035a0dae01..52bc2c0448df080fa8224a2e28b66f13d8c9246b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -34,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -53,12 +51,22 @@ namespace xla { namespace { +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::FlatSet; +using tensorflow::gtl::optional; + template struct is_complex_t : public std::false_type {}; template <> struct is_complex_t : public std::true_type {}; +template +struct is_complex64_t : public std::false_type {}; + +template <> +struct is_complex64_t : public std::true_type {}; + template StatusOr> Compare(const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, @@ -101,11 +109,10 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, } auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -132,11 +139,10 @@ StatusOr> Compare( } auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -161,8 +167,8 @@ StatusOr> ElementWiseUnaryOpImpl( auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); @@ -174,7 +180,7 @@ StatusOr> ElementWiseUnaryOpImpl( // with the base index. void IterateThroughWindow( const Shape& window_shape, const Window& window, const Shape& base_shape, - const tensorflow::gtl::ArraySlice& window_count_index, + const ArraySlice& window_count_index, const std::function&)>& f) { const int64 rank = ShapeUtil::Rank(base_shape); DimensionVector window_index(rank); @@ -196,6 +202,25 @@ void IterateThroughWindow( } while (IndexUtil::BumpIndices(window_shape, &window_index)); } +// Creates a vector of multipliers which can be used to create a linear index +// into shape. +// +// Given the multidimensional index {i1, ..., iN} and +// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply +// +// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. +// +// This lets you calculate LI given the multidimensional indices in any order. +DimensionVector MakeDimMultipliers(const Shape& shape) { + DimensionVector v(ShapeUtil::Rank(shape)); + int64 scale = 1; + for (auto dim : LayoutUtil::MinorToMajor(shape)) { + v[dim] = scale; + scale *= shape.dimensions(dim); + } + return v; +} + } // namespace template @@ -250,17 +275,37 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. + if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } return HandleAbs(abs); } @@ -308,13 +353,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { operand_to_broadcast.shape().dimensions(i)); } - return output->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); + return output->Populate([&](ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + }); } template < @@ -355,6 +399,22 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleBitcastConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + Status HandleExp(HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { @@ -588,14 +648,25 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::fmax(lhs, rhs); + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; })); return Status::OK(); } @@ -611,18 +682,30 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleMaximum(maximum); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> Status HandleMinimum(HloInstruction* minimum) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return std::fmin(lhs_el, rhs_el); + return std::min(lhs_el, rhs_el); })); return Status::OK(); } + template ::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -742,7 +825,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[shl], ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return lhs_elem << rhs_elem; + return IsShiftOutOfBounds(rhs_elem) ? 0 + : (lhs_elem << rhs_elem); })); return Status::OK(); } @@ -767,8 +851,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[shr], ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - return static_cast(static_cast(lhs_elem) >> - rhs_elem); + SignedT lhs_signed = static_cast(lhs_elem); + if (IsShiftOutOfBounds(rhs_elem)) { + return lhs_signed < 0 ? static_cast(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } })); return Status::OK(); } @@ -795,7 +883,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[shr], ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { // If shift amount is greater than the number of bits, then return 0. - if (rhs_elem >= sizeof(UnsignedT) * CHAR_BIT) { + if (IsShiftOutOfBounds(rhs_elem)) { return static_cast(0); } return static_cast(static_cast(lhs_elem) >> @@ -822,7 +910,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmax(low, std::fmin(value, high)); + return std::fmin(high, std::fmax(value, low)); }; TF_ASSIGN_OR_RETURN( parent_->evaluated_[clamp], @@ -843,6 +931,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); CHECK(!ShapeUtil::IsTuple(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { @@ -873,8 +962,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -925,18 +1014,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - // Dimension number applicable for input (lhs). - const int64 input_batch_dim = dnums.input_batch_dimension(); - const int64 input_z_dim = dnums.input_feature_dimension(); - // Dimension number applicable for kernel (rhs). - const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); - const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); - // Dimension number applicable for output. - const int64 output_batch_dim = dnums.output_batch_dimension(); - const int64 output_z_dim = dnums.output_feature_dimension(); - - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - std::vector window_dimension_sizes; for (auto i : dnums.kernel_spatial_dimensions()) { window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); @@ -945,25 +1022,43 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Shape& window_shape = ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); - DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size()); + DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); + DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { - ElementwiseT result_val = static_cast(0); + auto lhs_literal_data = lhs_literal.data(); + auto rhs_literal_data = rhs_literal.data(); + + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, + &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, + rhs_literal_data](ArraySlice out_index) { + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); - std::fill(lhs_index.begin(), lhs_index.end(), 0); - std::fill(rhs_index.begin(), rhs_index.end(), 0); - std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0); + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - lhs_index[input_batch_dim] = out_index[output_batch_dim]; - rhs_index[kernel_output_z_dim] = out_index[output_z_dim]; + ElementwiseT result_val = static_cast(0); + DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), + 0); // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { - lhs_index[input_z_dim] = iz; - rhs_index[kernel_input_z_dim] = iz; + int64 lhs_linear_index = 0; + lhs_linear_index += out_index[output_batch_dim] * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; + + int64 rhs_linear_index = 0; + rhs_linear_index += out_index[output_z_dim] * + rhs_dim_multipliers[kernel_output_z_dim]; + rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { @@ -988,29 +1083,32 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // Calculate the actual lhs (input) index after dilation. As an // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; if (window_dim.base_dilation() > 1) { - lhs_index[input_spatial_dim] = - undilated_index / window_dim.base_dilation(); + lhs_spatial_index = undilated_index / window_dim.base_dilation(); } else { - lhs_index[input_spatial_dim] = undilated_index; + lhs_spatial_index = undilated_index; } + lhs_linear_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - // Skip if input index is not in bound. - if (!(lhs_index[input_spatial_dim] >= 0 && - lhs_index[input_spatial_dim] < + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < lhs_shape.dimensions(input_spatial_dim))) { goto cnt; } - rhs_index[dnums.kernel_spatial_dimensions(ki)] = - window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]; + rhs_linear_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; } result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); + static_cast(lhs_literal_data[lhs_linear_index]) * + static_cast(rhs_literal_data[rhs_linear_index]); } cnt : {} } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); @@ -1019,7 +1117,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { }; auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate(func)); + TF_RETURN_IF_ERROR(result->PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1071,9 +1169,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } std::vector rhs_non_batch_non_contracting_dims; - tensorflow::gtl::FlatSet batch_dims_set( - dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); + FlatSet batch_dims_set(dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); for (int64 i = 0; i < rhs_rank; i++) { if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { rhs_non_batch_non_contracting_dims.push_back(i); @@ -1085,8 +1182,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector lhs_index(lhs_rank); DimensionVector rhs_index(rhs_rank); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice result_index) { ElementwiseT result_val = static_cast(0); // Find the corresponding non-contracting indices for lhs and rhs. @@ -1180,9 +1277,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); auto result = Literal::CreateFromShape(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); + [&scalar](ArraySlice multi_index) { return scalar; })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); @@ -1195,7 +1290,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // corresponding index of the resulting padded literal. const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](const std::vector& input_index) { + auto func = [&](ArraySlice input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1345,9 +1440,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(map->shape()); - HloEvaluator embedded_evaluator; - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { std::vector> arg_literals; arg_literals.reserve(operands.size()); @@ -1437,7 +1532,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == ShapeUtil::Rank(arg->shape()) - dimensions.size()); @@ -1469,21 +1564,19 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { arg_dim_counts[dim] = arg_dimensions[dim]; } - // Create mapping from result index to arg index. - const int64 result_rank = ShapeUtil::Rank(result->shape()); - int64 result_dim = 0; - std::vector result_to_arg_index(result_rank); + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; for (int64 i = 0; i < arg_dimensions.size(); ++i) { if (arg_dim_steps[i] == 0) { - result_to_arg_index[result_dim] = i; - ++result_dim; + result_to_arg_index.push_back(i); } } - HloEvaluator embedded_evaluator; + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { ReturnT result_val = init_scalar; std::vector base(arg_dimensions.size()); @@ -1491,31 +1584,43 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { base[result_to_arg_index[i]] = multi_index[i]; } - auto func = [&](const std::vector& input_index) { + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](ArraySlice input_index) { + computed_result += arg_literal.Get(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = [&](ArraySlice input_index) { auto curr_val = arg_literal.Get(input_index); // Evaluate computation with specified literal operands. auto curr_val_literal = Literal::CreateR0(curr_val); auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {curr_val_literal.get(), - result_val_literal.get()}; + std::vector args = {result_val_literal.get(), + curr_val_literal.get()}; std::unique_ptr computed_result = embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on + // Clear visit states so that we can use the evaluator again on // the same computation. embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. result_val = computed_result->Get({}); - return true; }; - + // Computes one element of the result, reducing all dimensions that + // contribute to that element. ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); - return result_val; })); @@ -1523,6 +1628,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { auto operand = select_and_scatter->operand(0); auto source = select_and_scatter->operand(1); @@ -1537,9 +1656,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); + [&](ArraySlice output_index) { return init_scalar; })); std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1556,7 +1673,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { int64 rank = ShapeUtil::Rank(operand_literal.shape()); - HloEvaluator embedded_evaluator; + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); DimensionVector source_index(rank); std::fill(source_index.begin(), source_index.end(), 0); @@ -1572,8 +1689,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // 2. Using the selected index, scatter value from `source` to result. We // do this by iterating through the window, and compare each index with // the selected index. - tensorflow::gtl::optional selected_val; - tensorflow::gtl::optional> selected_index; + optional selected_val; + optional> selected_index; IterateThroughWindow( window_shape, window, operand_literal.shape(), source_index, @@ -1588,11 +1705,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Literal::CreateR0(*selected_val); const std::vector args = { - curr_val_literal.get(), selected_val_literal.get()}; + selected_val_literal.get(), curr_val_literal.get()}; std::unique_ptr computed_result = embedded_evaluator.Evaluate(*select, args) .ConsumeValueOrDie(); - bool selected = computed_result->Get({}); + bool selected = !computed_result->Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1667,10 +1784,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector window_index(window.dimensions_size()); DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - HloEvaluator embedded_evaluator; + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1687,7 +1804,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const auto result_val_literal = Literal::CreateR0(result_val); const std::vector args = { - curr_val_literal.get(), result_val_literal.get()}; + result_val_literal.get(), curr_val_literal.get()}; std::unique_ptr computed_result = embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); @@ -1720,7 +1837,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const int64 rank = ShapeUtil::Rank(operand->shape()); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](ArraySlice out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { operand_index[i] = @@ -1901,8 +2018,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::vector operand_indices(start.size()); auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); // Mod is only used here to be consistent with the existing @@ -1922,17 +2039,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> DynamicUpdateSlice( const Literal& operand_literal, const Literal& update_literal, const Literal& start_indices_literal) { - auto start_indices_typed = start_indices_literal.data(); - const std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - auto result = operand_literal.CloneToUnique(); - std::vector result_index(ShapeUtil::Rank(result->shape()), 0); + auto start_indices_typed = start_indices_literal.data(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector start(rank, 0); + for (int64 i = 0; i < rank; ++i) { + // All other implementations currently wrap-around the index, so this + // should do so as well. + start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); + start[i] += (start[i] < 0) * result->shape().dimensions(i); + } + std::vector result_index(rank, 0); - auto func = [&](const std::vector& update_index) { + auto func = [&](ArraySlice update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - + // Same as above, wrap-around only to match other implementations' + // semantics. + std::transform(result_index.begin(), result_index.end(), + result->shape().dimensions().begin(), result_index.begin(), + std::modulus()); result->Set(result_index, update_literal.Get(update_index)); return true; @@ -1985,8 +2111,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2023,8 +2149,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](ArraySlice multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); @@ -2033,20 +2159,31 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } + template + static bool IsShiftOutOfBounds(NativeT rhs) { + typedef typename std::make_unsigned::type UnsignedT; + UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; + UnsignedT rhs_unsigned = static_cast(rhs); + return rhs_unsigned >= lhs_size_unsigned; + } + HloEvaluator* parent_; }; // class HloEvaluator::TypedVisitor -HloEvaluator::HloEvaluator() { +HloEvaluator::HloEvaluator(int64 max_loop_iterations) + : max_loop_iterations_(max_loop_iterations) { typed_visitors_[PRED] = MakeUnique>(this); typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: U16."); + return Unimplemented( + "HloEvaluator::TypedVisitor: unhandled primitive type: U16."); }); typed_visitors_[U32] = MakeUnique>(this); typed_visitors_[U64] = MakeUnique>(this); typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: S16."); + return Unimplemented( + "HloEvaluator::TypedVisitor: unhandled primitive type: S16."); }); typed_visitors_[S32] = MakeUnique>(this); typed_visitors_[S64] = MakeUnique>(this); @@ -2060,18 +2197,20 @@ HloEvaluator::HloEvaluator() { // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. typed_visitors_[BF16] = MakeUnique>(this); + typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); + return Unimplemented( + "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE."); + return Unimplemented( + "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE."); }); } template StatusOr> HloEvaluator::Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals) { + const HloModule& module, ArraySlice arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -2088,8 +2227,8 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals) { + const HloComputation& computation, ArraySlice arg_literals) { + CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ -2105,8 +2244,7 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals) { + HloInstruction* instruction, ArraySlice arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); @@ -2231,8 +2369,7 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + ArraySlice operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -2404,6 +2541,349 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output +// gather dimensions while keeping the rest of the output dimensions clamped to +// 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( + const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { + int64 output_rank = output_shape.dimensions_size(); + std::vector index_base(output_rank, 0); + std::vector index_count; + index_count.reserve(output_rank); + for (int64 i = 0; i < output_rank; i++) { + bool is_output_gather_dim = + !c_binary_search(dim_numbers.output_window_dims(), i); + index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) + : 1); + } + + return {std::move(index_base), std::move(index_count), + std::vector(output_rank, 1)}; +} + +// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// dimensions while keeping the rest of the output dimensions clamped to 0. +ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( + int64 output_rank, ArraySlice window_bounds, + const GatherDimensionNumbers& dim_numbers) { + std::vector index_base(output_rank, 0); + std::vector index_count(output_rank, 1); + int64 window_bounds_idx = 0; + for (int64 i = 0; i < output_rank; i++) { + bool is_output_window_dim = + c_binary_search(dim_numbers.output_window_dims(), i); + if (is_output_window_dim) { + while (c_binary_search(dim_numbers.elided_window_dims(), + window_bounds_idx)) { + window_bounds_idx++; + } + index_count[i] = window_bounds[window_bounds_idx++]; + } + } + + return {std::move(index_base), std::move(index_count), + std::vector(output_rank, 1)}; +} + +// This functor computes the contribution of gather_indices to an input index +// corresponding to an output index. That is, given an output index I, it picks +// out the gather output indices in I and uses them to look up a gather index, +// G, from the gather indices tensor, and expands G into the input space +// according to gather_dims_to_operand_dims. +class OutputGatherIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit OutputGatherIndexToInputIndex( + const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, + const Shape& output_shape, const Literal* gather_indices) + : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + for (int64 i = 0; i < output_shape.dimensions_size(); i++) { + output_dim_is_gather_dims_.push_back( + !c_binary_search(dim_numbers_.output_window_dims(), i)); + } + + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + int64 index_of_input_dim_in_index_vector = + std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), + c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + if (index_of_input_dim_in_index_vector == + dim_numbers_.gather_dims_to_operand_dims_size()) { + input_dim_value_to_index_vector_.push_back(-1); + } else { + input_dim_value_to_index_vector_.push_back( + index_of_input_dim_in_index_vector); + } + } + + index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + input_index_.resize(input_shape.dimensions_size()); + int64 index_vector_size = + gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + index_vector_.resize(index_vector_size); + } + + // Returns the contribution of gather_indices to the input index corresponding + // to output_index. See gather_inner_loop_body. + // + // This is conceptually a stateless transformation from output_index to the + // gather input index, but: + // + // - Instead of allocating memory to represent the gather input index on + // every invocation we reuse the same storage for the result + // (input_index_), mutating it in place. + // - Instead of allocating buffers for temporary values like + // index_vector_index_ and index_vector on every invocation, we reuse the + // same storage for all invocations. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()(ArraySlice output_index) { + PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); + TF_RETURN_IF_ERROR(FetchIndexVector()); + PropagateIndexVectorToInputIndex(); + return ArraySlice(input_index_); + } + + private: + // Propagates the gather index dimensions from the output index into + // index_vector_index_ by mutating index_vector_index_ in place. Does not + // update the dim_numbers.index_vector_dim() dimension -- that's the dimension + // we iterate over in FetchIndexVector. + void PropagateOutputIndexGatherDimsToIndexVectorIndex( + ArraySlice output_index) { + int64 index_vector_index_i = 0; + for (int64 i = 0, e = output_index.size(); i < e; i++) { + if (!output_dim_is_gather_dims_[i]) { + continue; + } + + if (index_vector_index_i == dim_numbers_.index_vector_dim()) { + index_vector_index_i++; + } + + index_vector_index_[index_vector_index_i++] = output_index[i]; + } + } + + // Populates index_vector_ by iterating over gather_indices_ according to + // index_vector_index_. + Status FetchIndexVector() { + int64 index_vector_dim = dim_numbers_.index_vector_dim(); + for (int64 i = 0, e = index_vector_.size(); i < e; i++) { + index_vector_index_[index_vector_dim] = i; + TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( + index_vector_index_)); + } + return Status::OK(); + } + + // Populates input_index_. + void PropagateIndexVectorToInputIndex() { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_index_vector_[i] != -1) { + input_index_[i] = index_vector_[input_dim_value_to_index_vector_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of + // the input index from the index vector. See + // PropagateIndexVectorToInputIndex. + std::vector input_dim_value_to_index_vector_; + + // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // dimension. + std::vector output_dim_is_gather_dims_; + + // The buffer into which we construct an index into gather_indices_ to fetch + // the index vector. + std::vector index_vector_index_; + + // The index vector fetched from gather_indices_. + std::vector index_vector_; + + // The result computed by this functor. operator() returns an ArraySlice into + // this vector. + std::vector input_index_; + + const GatherDimensionNumbers& dim_numbers_; + const Literal& gather_indices_; +}; + +// This functor computes the contribution of the window indices in an output +// index to an input index. That is, given an output index I it picks out the +// output window indices in I and expands it into a window index into the input +// shape. +class OutputWindowIndexToInputIndex { + public: + // The constructor does some setup work that is amortized across all + // iterations. + explicit OutputWindowIndexToInputIndex( + const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, + const Shape& output_shape) { + std::vector window_index_to_output_index; + int64 output_index_count = 0; + for (int64 i = 0; i < output_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.output_window_dims(), i)) { + window_index_to_output_index.push_back(output_index_count++); + } else { + output_index_count++; + } + } + + int64 window_dim_count = 0; + for (int64 i = 0; i < input_shape.dimensions_size(); i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + input_dim_value_to_output_index_.push_back(-1); + } else { + input_dim_value_to_output_index_.push_back( + window_index_to_output_index[window_dim_count++]); + } + } + + input_index_.resize(input_shape.dimensions_size()); + } + + // Returns the contribution of the window indices to the input index + // corresponding to output_index. See gather_inner_loop_body. + // + // This is conceptually a stateless transformation from output_index to the + // window input index, but instead of allocating memory to represent the + // gather input index on every invocation we reuse the same storage for the + // result (input_index_), mutating it in place. + // + // This returns an arrayslice into memory owned by the class. + StatusOr> operator()(ArraySlice output_index) { + PropagateOutputIndexWindowDimsToInputIndex(output_index); + return ArraySlice(input_index_); + } + + private: + // Propagates window dimensions from the output index to input_index_ by + // mutating input_index_ in place. + void PropagateOutputIndexWindowDimsToInputIndex( + ArraySlice output_index) { + for (int64 i = 0, e = input_index_.size(); i < e; i++) { + if (input_dim_value_to_output_index_[i] != -1) { + input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; + } + + // If input_dim_value_to_index_vector_[i] == -1 then input_index_[i] + // remains 0, as set by the constructor. + } + } + + // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of + // the input index from the output index. See + // PropagateOutputIndexToInputIndex. + std::vector input_dim_value_to_output_index_; + + // The result computed by this functor. operator() returns an ArraySlice into + // this vector. + std::vector input_index_; +}; + +// Rehapes the gather indices input to have a trailing degenerate `1` dimension +// if necessary. Hands over the ownership of the newly created literal (if +// there is one) to `reshaped_gather_indices`. +static StatusOr> ReshapedGatherIndices( + int64 index_vector_dim, const Literal& gather_indices, + std::unique_ptr* reshaped_gather_indices) { + if (gather_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(gather_indices); + } + + std::vector new_shape(gather_indices.shape().dimensions().begin(), + gather_indices.shape().dimensions().end()); + new_shape.push_back(1); + TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, + gather_indices.Reshape(new_shape)); + return std::cref(**reshaped_gather_indices); +} + +Status HloEvaluator::HandleGather(HloInstruction* gather) { + std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + const Shape& shape = gather->shape(); + const GatherDimensionNumbers& dim_numbers = + gather->gather_dimension_numbers(); + const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); + std::unique_ptr reshaped_gather_indices; + TF_ASSIGN_OR_RETURN( + const Literal& gather_indices, + ReshapedGatherIndices(dim_numbers.index_vector_dim(), + GetEvaluatedLiteralFor(gather->operand(1)), + &reshaped_gather_indices)); + + // We iterate over the gather dimensions in the output shape in an outer loop + // nest, and iterate over the window dimensions in the output shape in an + // inner loop nest. + + ShapeUtil::IndexIterationSpace gather_indices_iteration_space = + IterationSpaceForOutputGatherIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace window_indices_iteration_space = + IterationSpaceForOutputWindowIndices( + shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + + // Scratch buffers that hold an index in the output shape and the + // corresponding index in the input shape. + std::vector input_index(operand.shape().dimensions_size()); + std::vector output_index(gather->shape().dimensions_size()); + + OutputGatherIndexToInputIndex output_gather_index_to_input_index( + &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), + /*output_shape=*/shape, &gather_indices); + OutputWindowIndexToInputIndex output_window_index_to_input_index( + gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), + /*output_shape=*/shape); + + const Shape& operand_shape = operand.shape(); + + auto gather_inner_loop_body = + [&](ArraySlice output_window_index, + ArraySlice input_gather_index, + ArraySlice output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + ArraySlice input_window_index, + output_window_index_to_input_index(output_window_index)); + for (int i = 0, e = output_index.size(); i < e; i++) { + output_index[i] = output_gather_index[i] + output_window_index[i]; + DCHECK_LT(output_index[i], shape.dimensions(i)); + } + for (int i = 0, e = input_index.size(); i < e; i++) { + // TODO(b/74360564): We should implement whatever out of bounds behavior + // we decide for dynamic-slice here as well. + input_index[i] = (input_gather_index[i] + input_window_index[i]) % + operand_shape.dimensions(i); + if (input_index[i] < 0) { + input_index[i] += operand_shape.dimensions(i); + } + } + TF_RETURN_IF_ERROR( + result->CopyElementFrom(operand, input_index, output_index)); + return true; + }; + + auto gather_outer_loop_body = + [&](ArraySlice output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN( + ArraySlice input_gather_index, + output_gather_index_to_input_index(output_gather_index)); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + shape, window_indices_iteration_space, + std::bind(gather_inner_loop_body, std::placeholders::_1, + input_gather_index, output_gather_index))); + return true; + }; + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + shape, gather_indices_iteration_space, gather_outer_loop_body)); + evaluated_[gather] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -2434,6 +2914,135 @@ Status HloEvaluator::HandleCopy(HloInstruction* copy) { return Status::OK(); } +Status HloEvaluator::HandleCall(HloInstruction* call) { + auto* computation = call->to_apply(); + auto operands = call->operands(); + + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[call] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleFusion(HloInstruction* fusion) { + // Attach cloned computation to an empty HLO module so the existing ones are + // not modified. + HloModule empty_hlo_module("EmptyModuleForFusion"); + auto cloned_fused_computation = + fusion->fused_instructions_computation()->Clone( + /*suffix=*/"clone_with_layout", &empty_hlo_module); + for (auto* instruction : cloned_fused_computation->instructions()) { + LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); + } + auto readded_computation = + empty_hlo_module.AddEntryComputation(std::move(cloned_fused_computation)); + + auto operands = fusion->operands(); + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator + .Evaluate(*readded_computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[fusion] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleConditional(HloInstruction* conditional) { + const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); + const auto& true_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(1)); + const auto& false_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(2)); + + auto* true_computation = conditional->true_computation(); + auto* false_computation = conditional->false_computation(); + + auto result = Literal::CreateFromShape(conditional->shape()); + HloEvaluator embedded_evaluator; + if (pred.Get({})) { + result = embedded_evaluator + .Evaluate(*true_computation, + {&true_computation_arg}) + .ConsumeValueOrDie(); + } else { + result = embedded_evaluator + .Evaluate(*false_computation, + {&false_computation_arg}) + .ConsumeValueOrDie(); + } + + evaluated_[conditional] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleSelect(HloInstruction* select) { + const auto& pred = GetEvaluatedLiteralFor(select->operand(0)); + const auto& on_true = GetEvaluatedLiteralFor(select->operand(1)); + const auto& on_false = GetEvaluatedLiteralFor(select->operand(2)); + + // If predicate is of scalar type, no element-wise selection would be needed. + // This would also handle output array of tuple types as the DefaultAction + // would go through the TypedVisitor which doesn't handle tuples. + if (ShapeUtil::IsScalar(pred.shape())) { + if (pred.Get({})) { + evaluated_[select] = on_true.CloneToUnique(); + } else { + evaluated_[select] = on_false.CloneToUnique(); + } + return Status::OK(); + } + + return DefaultAction(select); +} + +Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { + HloComputation* cond_comp = while_hlo->while_condition(); + HloComputation* body_comp = while_hlo->while_body(); + // Initialize the loop carried valued with the input to the While instruction. + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + bool keep_going = true; + int64 iteration_count = 0; + HloEvaluator cond_evaluator(max_loop_iterations_); + HloEvaluator loop_body_evaluator(max_loop_iterations_); + while (keep_going) { + if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { + return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", + while_hlo->name().c_str(), max_loop_iterations_); + } + TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( + *cond_comp, {lcv.get()})); + keep_going = cond_val->GetFirstElement(); + if (keep_going) { + TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( + *body_comp, {lcv.get()})); + VLOG(3) << "Loop iteration result: " << body_val->ToString(); + lcv = std::move(body_val); + cond_evaluator.ResetVisitStates(); + loop_body_evaluator.ResetVisitStates(); + } + } + evaluated_[while_hlo] = std::move(lcv); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); @@ -2447,28 +3056,27 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { // Explicit instantiation of templatized Evaluate* methods. // -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(const HloModule& module, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( - const HloModule& module, - tensorflow::gtl::ArraySlice> arg_literals); + const HloModule& module, ArraySlice> arg_literals); -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(const HloComputation& computation, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( const HloComputation& computation, - tensorflow::gtl::ArraySlice> arg_literals); + ArraySlice> arg_literals); -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate(HloInstruction* instruction, + ArraySlice arg_literals); template StatusOr> HloEvaluator::Evaluate>( HloInstruction* instruction, - tensorflow::gtl::ArraySlice> arg_literals); + ArraySlice> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 3b2b697e492a78a06a4e5ae6bf056ff8676f2ff5..c0dcee0c3e382f74de72a2b89f39e06f042e2b80 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -36,7 +36,10 @@ namespace xla { // This class is not thread-safe. class HloEvaluator : public DfsHloVisitorWithDefault { public: - HloEvaluator(); + // Only evaluate up to max_loop_iterations per while-loop execution if + // specified. + explicit HloEvaluator(int64 max_loop_iterations = -1); + // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. // Precondition: The indices of arg_literals correspond to the parameter @@ -149,10 +152,22 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; + Status HandleGather(HloInstruction* gather) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + + Status HandleCall(HloInstruction* call) override; + + Status HandleFusion(HloInstruction* fusion) override; + + Status HandleWhile(HloInstruction* while_hlo) override; + + Status HandleSelect(HloInstruction* select) override; + private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be @@ -190,6 +205,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Must be cleared for each evaluation. std::vector arg_literals_; + // Max loop iterations to execute with no maximum if negative. + int64 max_loop_iterations_; + TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 97765d65909cee192f65069777f8f195081603b2..dd14dd38537a83d0ee16cff9e3c22a38f544e208 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -1205,6 +1206,80 @@ TEST_P(HloEvaluatorTest, LiteralTestUtil::ExpectEqual(*expected, *result); } +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; + +// Tests that Reduce doesn't lose precision when adding many numbers (because +// it accumulates its result in a double). +TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { + HloComputation::Builder b(TestName()); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(v))); + HloInstruction* init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module().AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + std::unique_ptr result = + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, *result); +} + +// Reducing many numbers should be fast because it doesn't create +// intermediate Literals; the microbenchmark should finish in < 1 msec. +void BM_ReducePrecisely(int num_iters) { + tensorflow::testing::StopTiming(); + HloComputation::Builder b("BM_ReducePrecisely"); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + + constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 + std::vector v(kNumElements, 1.0f); + HloInstruction* arg_instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(v))); + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + + HloInstruction* reduce_instruction = b.AddInstruction( + HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value, + /*dimensions_to_reduce=*/{0}, add_func)); + module.AddEntryComputation(b.Build()); + + HloEvaluator hlo_eval; + tensorflow::testing::StartTiming(); + hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + tensorflow::testing::StopTiming(); +} + +BENCHMARK(BM_ReducePrecisely); + TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); @@ -1729,6 +1804,207 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { *result.ValueOrDie()); } +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR3( + {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { + const char* hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, + EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { + const char* hlo_text = R"( +HloModule DynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,0] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 0} +} +)"; + ParseAndVerifyModule(hlo_text); + std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + LiteralTestUtil::ExpectEqual( + *Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()})); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, ::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index f0df93b61d29c1535d8a89fbd65e669de5b43729..c3ccbf0f0c75b569b49652807dea52faebdccc31 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -111,8 +111,8 @@ HloExecutionProfile::HloExecutionProfile( : hlo_profile_printer_data_(*hlo_profile_printer_data), hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( - /*count*/ hlo_profile_index_map_.total_count(), - /*value*/ 0) {} + /*count=*/hlo_profile_index_map_.total_count(), + /*value=*/0) {} void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 44fcd36370dcd0cf77601aa1cd2b92810947bd5f..c35783c456c63b9a651d1221cf9a3d70af38ba66 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -157,52 +157,60 @@ enum ColorScheme { kDashedBorder, }; +// Graphviz attributes/colors that make up a color scheme. +struct NodeColors { + const char* style; + const char* fill_color; + const char* stroke_color; + const char* font_color; +}; + +NodeColors NodeColorsForScheme(ColorScheme color) { + switch (color) { + case kBlue: + return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"}; + case kBrown: + return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"}; + case kDarkBlue: + return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; + case kDarkGreen: + return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkRed: + return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; + case kGray: + return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"}; + case kGreen: + return NodeColors{"filled", "#c8e6c9", "#97b498", "black"}; + case kOrange: + return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"}; + case kPurple: + return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"}; + case kRed: + return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"}; + case kWhite: + return NodeColors{"filled", "white", "black", "black"}; + case kYellow: + return NodeColors{"filled", "#fff9c4", "#cbc693", "black"}; + case kDashedBorder: + // "filled,dashed" looks the same as "dashed", since we have a white + // background. But we use "filled,dashed" so that when you hover over + // any part of the node (not just the text inside the node), our css + // :hover rule is triggered. + return NodeColors{"filled,dashed", "white", "#757575", "#757575"}; + } +} + // Given a ColorScheme, returns an attribute string for a node of that color. // Sets the node's style and fill/stroke/text colors. // // Colors are from https://material.io/color. string NodeColorAttributes(ColorScheme color) { - using std::make_tuple; - - const char *style, *fill_color, *stroke_color, *font_color; - std::tie(style, fill_color, stroke_color, font_color) = [color] { - switch (color) { - case kBlue: - return make_tuple("filled", "#bbdefb", "#8aacc8", "black"); - case kBrown: - return make_tuple("filled", "#bcaaa4", "#8c7b75", "black"); - case kDarkBlue: - return make_tuple("filled", "#1565c0", "#003c8f", "white"); - case kDarkGreen: - return make_tuple("filled", "#2e7d32", "#005005", "white"); - case kDarkRed: - return make_tuple("filled", "#b71c1c", "#7f0000", "white"); - case kGray: - return make_tuple("filled", "#cfd8dc", "#9ea7aa", "black"); - case kGreen: - return make_tuple("filled", "#c8e6c9", "#97b498", "black"); - case kOrange: - return make_tuple("filled", "#ffe0b2", "#cbae82", "black"); - case kPurple: - return make_tuple("filled", "#e1bee7", "#af8eb5", "black"); - case kRed: - return make_tuple("filled", "#ffcdd2", "#cb9ca1", "black"); - case kWhite: - return make_tuple("filled", "white", "black", "black"); - case kYellow: - return make_tuple("filled", "#fff9c4", "#cbc693", "black"); - case kDashedBorder: - // "filled,dashed" looks the same as "dashed", since we have a white - // background. But we use "filled,dashed" so that when you hover over - // any part of the node (not just the text inside the node), our css - // :hover rule is triggered. - return make_tuple("filled,dashed", "white", "#757575", "#757575"); - } - }(); + NodeColors node_colors = NodeColorsForScheme(color); return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", style, - font_color, stroke_color, fill_color); + R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, node_colors.stroke_color, + node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a @@ -604,11 +612,21 @@ tooltip = " "; StrAppend(&subcomp_label, "
", extra_info); } - // Subcomputation's fill/stroke color is light/dark red/gray, depending on - // whether or not the subcomputation's fusion node is highlighted. bool highlight = filter_.Highlight(parent_instr); - const char* fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; - const char* strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + const char* fillcolor; + const char* strokecolor; + if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) { + // Use the sharding color, if the node isn't highlighted. + NodeColors node_colors = + NodeColorsForScheme(GetInstructionColor(parent_instr)); + fillcolor = node_colors.fill_color; + strokecolor = node_colors.stroke_color; + } else { + // Subcomputation's fill/stroke color is light/dark red/gray, depending on + // whether or not the subcomputation's fusion node is highlighted. + fillcolor = highlight ? "#ffcdd2" : "#f5f5f5"; + strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; + } style = Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", fillcolor, strokecolor); @@ -782,6 +800,14 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( auto stringify_constant = [](const HloInstruction* constant) { const auto& shape = constant->shape(); + // If the shape has a dimension of size zero, print it as e.g. + // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), + // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which + // is just noise. + if (ShapeUtil::HasZeroElements(shape)) { + return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + } + // Print the literal value of constants with <= K elements. optional elem_count; if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { @@ -797,7 +823,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::StringPiece(constant->name()).starts_with("constant")) { + if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); @@ -930,6 +956,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: + case HloOpcode::kBroadcastDimOne: // De-emphasize nodes which broadcast a scalar within a fusion node -- // these are essentially free. if (instr->IsFused() && @@ -940,6 +967,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConcatenate: case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: @@ -988,6 +1016,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: @@ -1013,8 +1042,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::StringPiece(instr->name()) - .starts_with(HloOpcodeString(instr->opcode()))) { + if (tensorflow::str_util::StartsWith(instr->name(), + HloOpcodeString(instr->opcode()))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 0981f1f4fe57751d5b7059b4b08099385369e4b9..56cb241087cf31084df76c25ead89d477cd38f0f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -51,24 +52,22 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation) { + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const string& operand_name : proto.operand_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) - << "No instruction named " << operand_name; - instruction->AppendOperand(instruction_map.at(operand_name)); - } - for (const string& predecessor_name : proto.control_predecessor_names()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) - << "No instruction named " << predecessor_name; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) ->AddControlDependencyTo(instruction.get())); } @@ -76,26 +75,36 @@ StatusOr> HloInstruction::CreateFromProto( // HloInstructionProto and do not appear as an HloComputationProto within the // HloModuleProto. if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(proto.has_fused_instructions_computation()); TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + fused_computation->SetFusionInstruction(instruction.get()); + instruction->called_computations_.push_back(fused_computation); } else { - for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_name)) - << "No computation named " << computation_name; + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; instruction->called_computations_.push_back( - computation_map.at(computation_name)); + computation_map.at(computation_id)); } } + if (instruction->opcode() == HloOpcode::kTrace) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "Trace instruction should have 1 operand but sees " + << instruction->operands().size(); + instruction->mutable_operand(0)->set_tracing(instruction.get()); + } + TF_RET_CHECK(!proto.name().empty()); instruction->name_ = proto.name(); @@ -150,6 +159,23 @@ StatusOr> HloInstruction::CreateFromProto( instruction->fft_length_.push_back(fft_len); } + if (proto.has_sharding()) { + TF_ASSIGN_OR_RETURN(const auto& sharding, + HloSharding::FromProto(proto.sharding())); + instruction->set_sharding(sharding); + } + + if (proto.has_gather_dimension_numbers()) { + instruction->gather_dimension_numbers_ = + MakeUnique(proto.gather_dimension_numbers()); + } + for (int64 bound : proto.gather_window_bounds()) { + instruction->gather_window_bounds_.push_back(bound); + } + + instruction->channel_name_ = proto.channel_name(); + instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); + return std::move(instruction); } @@ -168,6 +194,7 @@ StatusOr> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_ = Literal::CreateR1U8(tag); + operand->set_tracing(instruction.get()); return instruction; } @@ -182,6 +209,7 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); auto instruction = WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); instruction->tuple_index_ = index; @@ -672,6 +700,15 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBroadcastDimOne(const Shape& shape, + HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBroadcastDimOne, shape)); + instruction->AppendOperand(operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, @@ -801,6 +838,32 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { return instruction; } +void HloInstruction::SetupDerivedInstruction( + HloInstruction* derived_instruction) const { + if (sharding_ != nullptr) { + derived_instruction->set_sharding(*sharding_); + } else { + derived_instruction->clear_sharding(); + } + derived_instruction->set_metadata(metadata_); +} + +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + CHECK_EQ(opcode(), HloOpcode::kFusion); + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -993,13 +1056,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Clone's operand was not already an operand of the fusion // instruction. Add it as an operand and add a corresponding fused // parameter instruction. - int64 param_no = fused_parameters.size(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(operand->name(), ".param_", param_no); - fused_param = fused_instructions_computation()->AddParameter( - CreateParameter(param_no, operand->shape(), param_name)); - AppendOperand(operand); + fused_param = AddFusionOperand(operand); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } @@ -1084,6 +1141,7 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: + case HloOpcode::kHostCompute: return true; default: { // Check if any of the called computations has a side effect. @@ -1121,6 +1179,19 @@ bool HloInstruction::HasSideEffect() const { return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->channel_name_ = channel_name.ToString(); + instruction->cost_estimate_ns_ = cost_estimate_ns; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; @@ -1131,6 +1202,40 @@ bool HloInstruction::HasSideEffect() const { return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); } +/* static */ std::unique_ptr HloInstruction::CreateGather( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(gather_indices); + instruction->gather_dimension_numbers_ = + MakeUnique(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); + return instruction; +} + +/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, @@ -1206,12 +1311,20 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBroadcast(shape, new_operands[0], dimensions_); break; + case HloOpcode::kBroadcastDimOne: + CHECK_EQ(new_operands.size(), 1); + clone = CreateBroadcastDimOne(shape, new_operands[0]); + break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); break; + case HloOpcode::kHostCompute: + clone = CreateHostCompute(shape, new_operands, channel_name_, + cost_estimate_ns_); + break; case HloOpcode::kConcatenate: clone = CreateConcatenate(shape, new_operands, dimensions(0)); break; @@ -1361,19 +1474,23 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kRecv: CHECK_EQ(new_operands.size(), 0); - clone = CreateRecv(shape, channel_id()); + // The shape is a tuple, but CreateRecv() wants the raw data shape. + clone = + CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); break; case HloOpcode::kRecvDone: CHECK_EQ(new_operands.size(), 1); clone = CreateRecvDone(new_operands[0]); break; + case HloOpcode::kGather: + CHECK_EQ(new_operands.size(), 2); + clone = CreateGather(shape, new_operands[0], new_operands[1], + *gather_dimension_numbers_, gather_window_bounds_); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } - clone->set_metadata(metadata_); - if (has_sharding()) { - clone->set_sharding(sharding()); - } + SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); return clone; } @@ -1710,6 +1827,11 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); + case HloOpcode::kGather: + return protobuf_util::ProtobufEquals(gather_dimension_numbers(), + other.gather_dimension_numbers()) && + gather_window_bounds() == other.gather_window_bounds(); + // FFT has various types & lengths. case HloOpcode::kFft: return fft_type() == other.fft_type() && @@ -1741,6 +1863,8 @@ bool HloInstruction::IdenticalSlowPath( // Remaining instructions with special values. case HloOpcode::kBitcast: + case HloOpcode::kBroadcastDimOne: + case HloOpcode::kDynamicUpdateSlice: return eq_shapes(shape(), other.shape()); case HloOpcode::kBroadcast: return eq_shapes(shape(), other.shape()) && @@ -1759,8 +1883,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDynamicSlice: return eq_shapes(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); @@ -1780,6 +1902,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kHostCompute: return false; } } @@ -2140,6 +2263,11 @@ std::vector HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } + if (gather_dimension_numbers_ != nullptr) { + extra.push_back(GatherDimensionNumbersToString()); + extra.push_back( + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); + } if (opcode() == HloOpcode::kFft) { extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); @@ -2232,14 +2360,18 @@ string HloInstruction::ToShortString() const { HloInstructionProto HloInstruction::ToProto() const { HloInstructionProto proto; + CHECK(unique_id_ != -1) + << "This instruction does not have a valid id. Please make sure the " + "instruction is inside a module before dumping it."; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_opcode(HloOpcodeString(opcode_)); *proto.mutable_shape() = shape_; for (const HloInstruction* operand : operands_) { - *proto.add_operand_names() = operand->name(); + proto.add_operand_ids(operand->unique_id()); } for (const HloInstruction* control : control_predecessors_) { - *proto.add_control_predecessor_names() = control->name(); + proto.add_control_predecessor_ids(control->unique_id()); } *proto.mutable_metadata() = metadata_; @@ -2249,11 +2381,11 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_parameter_number(parameter_number_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); - *proto.mutable_fused_instructions_computation() = - fused_instructions_computation()->ToProto(); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); } else { for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); + proto.add_called_computation_ids(computation->unique_id()); } } @@ -2271,6 +2403,14 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } + if (gather_dimension_numbers_ != nullptr) { + *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; + } + if (opcode() == HloOpcode::kGather) { + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2300,6 +2440,15 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_fft_length(fft_len); } + if (gather_dimension_numbers_ != nullptr) { + *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; + } + for (int64 bound : gather_window_bounds_) { + proto.add_gather_window_bounds(bound); + } + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); + return proto; } @@ -2543,6 +2692,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleBitcast(this); case HloOpcode::kBroadcast: return visitor->HandleBroadcast(this); + case HloOpcode::kBroadcastDimOne: + return visitor->HandleBroadcastDimOne(this); case HloOpcode::kPad: return visitor->HandlePad(this); case HloOpcode::kReshape: @@ -2565,6 +2716,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); + case HloOpcode::kHostCompute: + return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case HloOpcode::kWhile: @@ -2585,13 +2738,17 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSend(this); case HloOpcode::kSendDone: return visitor->HandleSendDone(this); + case HloOpcode::kGather: + return visitor->HandleGather(this); // These opcodes are not handled here. case HloOpcode::kTrace: break; } - return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s", - HloOpcodeString(opcode_).c_str()); + return InternalError( + "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " + "please file a bug for XLA.", + HloOpcodeString(opcode_).c_str()); } // Explicit instantiations. @@ -3268,6 +3425,26 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +string HloInstruction::GatherDimensionNumbersToString() const { + CHECK_NE(gather_dimension_numbers_.get(), nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3170746157fbcfa7d0a7eaba6d226d46691105f9..49aa07502996b698bb20f2c2e9d1d371d43d1793 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -179,20 +179,15 @@ class HloInstruction { // module: the module which will contain the instruction. The newly created // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. - // instruction_map: a map from instruction name to HloInstruction*. This map + // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. - // computation_map: a map from computation name to HloComputation*. This map + // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed instruction // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap& instruction_map, - const tensorflow::gtl::FlatMap& computation_map, - const std::function)>& - add_fused_computation); + const tensorflow::gtl::FlatMap& instruction_map, + const tensorflow::gtl::FlatMap& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -406,6 +401,10 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions); + // Creates a broadcast-size-one-dimensions instruction. + static std::unique_ptr CreateBroadcastDimOne( + const Shape& shape, HloInstruction* operand); + // Creates a sequence of instructions that performs an explicit broadcast of // the operand to the target shape. // @@ -451,6 +450,12 @@ class HloInstruction { HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation); + static std::unique_ptr CreateGather( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -475,6 +480,12 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target); + // Creates a HostCompute instruction, which records host-side control and + // data dependencies for use in instruction scheduling. + static std::unique_ptr CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( @@ -486,6 +497,13 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -767,6 +785,10 @@ class HloInstruction { // // (We express the default options using an overload rather than a default // param because gdb ignores default params, but does resolve overloads.) + // + // TODO(b/73348663): Make ToString() adaptive to the size of the string by + // default, backing off on providing full information for very large strings, + // or provide a different name for a ToString-like function that does that. string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; @@ -802,6 +824,12 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv int64 channel_id() const { return channel_id_; } + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. + // + // Precondition: opcode() == HloOpcode::kHostCompute + string channel_name() const { return channel_name_; } + // Returns feature_index field associated with the instruction. The index // represents the index of the feature dimension. // @@ -904,6 +932,13 @@ class HloInstruction { const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } + // Returns the sharding unique device, if any. + tensorflow::gtl::optional sharding_unique_device() const { + if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + return tensorflow::gtl::optional(); + } + return sharding_->UniqueDevice().ValueOrDie(); + } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { @@ -914,6 +949,16 @@ class HloInstruction { // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // When creating a new instruction which either replaces, or shifts up (kCopy + // insertion case), another instruction, we need to make sure the certain + // properties of the new instruction are copied into the derived one. As of + // today, the metadata and sharding will be propagated to the derived + // instruction. + void SetupDerivedInstruction(HloInstruction* derived_instruction) const; + + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + // Merges the fused instructions from 'instruction_to_merge' into the // fused instruction set of 'this', updating operands as necessary. // @@ -1086,6 +1131,19 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + + tensorflow::gtl::ArraySlice gather_window_bounds() const { + CHECK_EQ(opcode(), HloOpcode::kGather); + return gather_window_bounds_; + } + + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1350,6 +1408,9 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; + std::unique_ptr gather_dimension_numbers_; + std::vector gather_window_bounds_; + // Describes FFT type for an FFT instruction. FftType fft_type_ = FftType::FFT; @@ -1388,6 +1449,12 @@ class HloInstruction { // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; + // Name to use for host send/recv channels, only present for kHostCompute. + string channel_name_; + + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_ = 0; + // Computations called by this instruction. std::vector called_computations_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 94e9bfe56eb445ec0b459a55342cd3cc4c6f68ef..f2980d309d01fdf3b3e601bc260a0ad0895b3064 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1271,5 +1271,77 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, StringifyGather_0) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=4, window_bounds={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, StringifyGather_1) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=2, window_bounds={30,29,28,27,26}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 60270b0595dcfca8f1fcea5ab0914428880f35b5..08b9a29aeda2ee612d49b0788acf8438a25eb6a3 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -83,6 +83,11 @@ HloComputation* HloModule::AddComputationInternal( for (auto* instruction : computation->instructions()) { instruction->SetUniqueId(NewUniqueInstructionId()); } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); + computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -145,6 +150,21 @@ void HloModule::ReplaceComputations( } break; } + case HloOpcode::kConditional: { + HloComputation* new_true_computation = + tensorflow::gtl::FindWithDefault( + replacements, instruction->true_computation(), nullptr); + if (new_true_computation != nullptr) { + instruction->set_true_computation(new_true_computation); + } + HloComputation* new_false_computation = + tensorflow::gtl::FindWithDefault( + replacements, instruction->false_computation(), nullptr); + if (new_false_computation != nullptr) { + instruction->set_false_computation(new_false_computation); + } + break; + } case HloOpcode::kSelectAndScatter: { HloComputation* new_select = tensorflow::gtl::FindWithDefault( replacements, instruction->select(), nullptr); @@ -189,90 +209,36 @@ string HloModule::ToString(const HloPrintOptions& options) const { HloModuleProto HloModule::ToProto() const { HloModuleProto proto; + proto.set_id(unique_id_); proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); + proto.set_entry_computation_id(entry_computation_->unique_id()); for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are added when the fusion instructions are created by - // HloInstruction::CreateFromProto. - if (computation->IsFusionComputation()) { - continue; - } HloComputationProto computation_proto = computation->ToProto(); + if (computation->name() == entry_computation_->name()) { + *proto.mutable_program_shape() = computation_proto.program_shape(); + } proto.add_computations()->Swap(&computation_proto); } return proto; } -namespace { - -// Construct a ProgramShape matching the shape of the parameters and root of the -// given module's entry computation. -StatusOr ProgramShapeFromProto(const HloModuleProto& module) { - const HloComputationProto* entry_computation = nullptr; - for (const HloComputationProto& computation : module.computations()) { - if (computation.name() == module.entry_computation_name()) { - entry_computation = &computation; - break; - } - } - TF_RET_CHECK(entry_computation != nullptr) - << "No computation with entry computation name" - << module.entry_computation_name(); - - tensorflow::gtl::FlatMap> parameters; - const HloInstructionProto* root = nullptr; - for (const HloInstructionProto& instruction : - entry_computation->instructions()) { - if (instruction.name() == entry_computation->root_name()) { - TF_RET_CHECK(root == nullptr) << "Entry computation has more than " - "one instruction with (root) name " - << instruction.name(); - root = &instruction; - } - if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { - TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number())) - << "Entry computation has more than one parameter instruction " - "with parameter number " - << instruction.parameter_number(); - parameters[instruction.parameter_number()] = {instruction.name(), - &instruction.shape()}; - } - } - TF_RET_CHECK(root != nullptr) - << "Entry computation is missing root instruction named " - << entry_computation->root_name(); - - ProgramShape program_shape; - *program_shape.mutable_result() = root->shape(); - for (int64 i = 0; i < parameters.size(); ++i) { - TF_RET_CHECK(ContainsKey(parameters, i)) - << "Entry computation missing parameter number " << i; - const string& name = parameters.at(i).first; - const Shape& shape = *parameters.at(i).second; - *program_shape.add_parameters() = shape; - program_shape.add_parameter_names(name); - } - - return std::move(program_shape); -} - -} // namespace - /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, const VersionedComputationHandle& entry_computation_handle) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. - TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape, - ProgramShapeFromProto(proto)); + TF_RET_CHECK(proto.has_program_shape()) + << "No program shape found in the proto"; + const auto& expected_program_shape = proto.program_shape(); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = module_config.entry_computation_layout().parameter_layout(i).shape(); - TF_RET_CHECK( - ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), + parameter_shape)) << "HloModuleConfig has different shape for parameter " << i << " than the HLO module. Expected: " << ShapeUtil::HumanStringWithLayout( @@ -281,7 +247,8 @@ StatusOr> HloModule::CreateFromProto( } const Shape& result_shape = module_config.entry_computation_layout().result_layout().shape(); - TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + TF_RET_CHECK( + ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " "Expected: " << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) @@ -290,26 +257,20 @@ StatusOr> HloModule::CreateFromProto( auto module = MakeUnique(proto.name(), entry_computation_handle, module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map, - /*add_fused_computation=*/ - [&module](std::unique_ptr fused_computation) { - module->AddComputationInternal(std::move(fused_computation), - /*is_entry=*/false, - /*uniquify_names=*/false); - })); + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); - TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); - string computation_name = computation->name(); + int64 computation_id = computation_proto.id(); + TF_RET_CHECK(computation_id != -1); + TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); // Don't uniquify names because we want names to be stable across // serialization and deserialization. - computation_map[computation_name] = module->AddComputationInternal( + computation_map[computation_id] = module->AddComputationInternal( std::move(computation), - /*is_entry=*/proto.entry_computation_name() == computation_name, + /*is_entry=*/proto.entry_computation_id() == computation_id, /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -319,10 +280,6 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; for (HloComputation* computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); @@ -338,11 +295,13 @@ StatusOr> HloModule::CreateFromProto( /* static */ StatusOr HloModule::CreateModuleConfigFromProto( - const HloModuleProto& module) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - ProgramShapeFromProto(module)); + const HloModuleProto& module, const DebugOptions& debug_options) { + TF_RET_CHECK(module.has_program_shape()) + << "No program shape found in the proto"; + const auto& program_shape = module.program_shape(); HloModuleConfig module_config(program_shape); + module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. @@ -563,6 +522,18 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { return module; } +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { + HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); + TF_CHECK_OK( + clone->root_instruction()->Accept([this](HloInstruction* instruction) { + instruction->ReplaceCalledComputations([this](HloComputation* callee) { + return DeepCloneComputation(callee); + }); + return Status::OK(); + })); + return clone; +} + uint64 HloModule::RandomNew64() const { tensorflow::mutex_lock l(rng_mutex_); return rng_(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 4bfe8d89ce0a285de6d05d4867aaa6b266d78d12..9f7f25202ba42b14e995ed5c47d1012dabc69332 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -85,6 +85,10 @@ class HloModule { // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; + // Performs a deep clone of the computation, by recursively cloning all + // the called computations as well. + HloComputation* DeepCloneComputation(HloComputation* computation); + // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); @@ -99,7 +103,7 @@ class HloModule { return config_.mutable_entry_computation_layout(); } - ComputationLayout entry_computation_layout() const { + const ComputationLayout& entry_computation_layout() const { return config_.entry_computation_layout(); } @@ -168,7 +172,7 @@ class HloModule { // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. static StatusOr CreateModuleConfigFromProto( - const HloModuleProto& module); + const HloModuleProto& module, const DebugOptions& debug_options); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. @@ -183,11 +187,6 @@ class HloModule { // Returns a randomly generated uint64. uint64 RandomNew64() const; - // Returns the unique name for a computation in this module. - string GetUniqueCompuationName(const string& prefix) { - return computation_name_uniquer_.GetUniqueName(prefix); - } - // Returns the NameUniquer for uniquing instruction names in this module. NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 822e2f1f53e5ee460b88c2241ecf7f6b91ef608b..4205b0402cb8b2c31141d65be652cd84c22e7262 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -40,7 +40,7 @@ void HloModuleConfig::SetDefaultComputationLayout( string HloModuleConfig::compilation_cache_key() const { string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); + tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index a5ee895e48448fbb8fa3879dc1b6764c1f9f6966..586a03d412681cacdd780f48e77baf4cd4c51415 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -63,9 +63,19 @@ class HloModuleConfig { return &(*entry_computation_layout_); } - // Sets/returns whether to enable HLO-level profiling. - bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } - void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } + // Returns whether to enable HLO-level profiling. + bool hlo_profiling_enabled() const { + return debug_options_.xla_hlo_profile(); + } + + // Sets/returns whether this is a "host module". Host modules are used to + // record the data- and control-flow dependencies of host side computation + // that communicates with compiled code. They are used for analysis and + // scheduling purposes, but no code is generated. + bool is_host_module() const { return is_host_module_; } + void set_is_host_module(bool is_host_module) { + is_host_module_ = is_host_module; + } // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } @@ -101,8 +111,8 @@ class HloModuleConfig { tensorflow::gtl::optional entry_computation_layout_; - // Whether to enable HLO-level profiling. - bool hlo_profiling_enabled_ = false; + // Whether this is a 'host module'. + bool is_host_module_ = false; // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..54c34ce116651608e6d91cdcba9c708ca3a5f75e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" + +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +string HloModuleGroupMetadata::TrackedInstruction::ToString() const { + string repr = + (instruction_ != nullptr) ? instruction_->ToShortString() : "NULL"; + switch (kind_) { + case ComputationKind::kInvalid: + repr += ":INVALID"; + break; + case ComputationKind::kWhileCondition: + repr += ":WHILE_CONDITION"; + break; + case ComputationKind::kWhileBody: + repr += ":WHILE_BODY"; + break; + case ComputationKind::kConditionalTrue: + repr += ":CONDITIONAL_TRUE"; + break; + case ComputationKind::kConditionalFalse: + repr += ":CONDITIONAL_FALSE"; + break; + } + return repr; +} + +/* static */ StatusOr> +HloModuleGroupMetadata::Build(const std::vector& modules) { + auto metadata = absl::make_unique(modules); + TF_RETURN_IF_ERROR(metadata->Build()); + return std::move(metadata); +} + +Status HloModuleGroupMetadata::Build() { + TF_RETURN_IF_ERROR(RecordInstructions()); + TF_RETURN_IF_ERROR(VerifyChannelInstructions()); + + // Record all companion while instructions. + const auto visitor = [this](HloInstruction* hlo) -> Status { + // We only need to process if the instruction is within the computation + // of a companion instruction, like in the condition or body computation + // of a While. + const TrackedInstruction* tracked = GetTrackedInstruction(hlo->parent()); + if (tracked == nullptr) { + return Status::OK(); + } + // Add the parent computation of this channel instruction and its peer + // computation (both must be while computations) as companions. + if (IsChannelInstruction(hlo)) { + HloComputation* peer_computation = PeerComputation(hlo); + const TrackedInstruction* peer_tracked = + GetTrackedInstruction(peer_computation); + TF_RET_CHECK(peer_tracked != nullptr) + << "Peer instruction is not a possible companion"; + TF_RET_CHECK(*tracked == *peer_tracked) + << "Peer instruction does not match the computation kind"; + TF_RETURN_IF_ERROR( + AddCompanion(tracked->instruction(), peer_tracked->instruction())); + } + + // Add the parents of companion instructions (they must be all of the same + // kind of instructions, opcode wise) as companions. + if (IsCompanionInstruction(hlo)) { + for (HloInstruction* companion : Companions(hlo)) { + const TrackedInstruction* companion_tracked = + GetTrackedInstruction(companion->parent()); + TF_RET_CHECK(companion_tracked != nullptr); + TF_RET_CHECK(*tracked == *companion_tracked); + TF_RETURN_IF_ERROR(AddCompanion(tracked->instruction(), + companion_tracked->instruction())); + } + } + return Status::OK(); + }; + + // Visit the computations in postorder so that the companion information grows + // from inner computations to outer ones. + for (HloModule* module : modules_) { + for (HloComputation* computation : module->MakeComputationPostOrder()) { + TF_RETURN_IF_ERROR(computation->Accept(visitor)); + } + } + return Status::OK(); +} + +bool HloModuleGroupMetadata::IsChannelInstruction( + const HloInstruction* instruction) const { + switch (instruction->opcode()) { + case HloOpcode::kSend: + case HloOpcode::kRecv: + case HloOpcode::kSendDone: + case HloOpcode::kRecvDone: + return true; + default: + return false; + } +} + +bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const { + return companion_set_index_.count(hlo) > 0; +} + +bool HloModuleGroupMetadata::InstructionCommunicates( + HloInstruction* hlo) const { + return IsChannelInstruction(hlo) || IsCompanionInstruction(hlo); +} + +const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( + int64 channel_id) const { + CHECK(channel_id_map_.find(channel_id) != channel_id_map_.end()); + return channels_[channel_id_map_.at(channel_id)]; +} + +HloComputation* HloModuleGroupMetadata::PeerComputation( + const HloInstruction* instruction) const { + CHECK(IsChannelInstruction(instruction)); + const Channel& channel = GetChannel(instruction->channel_id()); + switch (instruction->opcode()) { + case HloOpcode::kSend: + case HloOpcode::kSendDone: + return channel.recv->parent(); + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + return channel.send->parent(); + default: + LOG(FATAL) << "opcode not supported"; + } +} + +std::vector +HloModuleGroupMetadata::GetCompanionsPath(const HloInstruction* hlo) const { + std::vector path; + const HloComputation* parent = hlo->parent(); + const TrackedInstruction* companion; + while ((companion = GetTrackedInstruction(parent)) != nullptr) { + parent = companion->instruction()->parent(); + path.push_back(*companion); + } + return path; +} + +bool HloModuleGroupMetadata::CheckCompanionPathsCompatibility( + const std::vector& path0, + const std::vector& path1) const { + if (path0.size() != path1.size()) { + VLOG(5) << "Companion path size do not match: " << path0.size() + << " != " << path1.size(); + return false; + } + for (int64 i = 0; i < path0.size(); ++i) { + if (path0[i] != path1[i]) { + VLOG(5) << "Companion instructions at path index " << i + << " do not have the same opcode: " << path0[i].ToString() + << " vs " << path1[i].ToString(); + return false; + } + } + return true; +} + +int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { + for (int64 i = 0; i < modules_.size(); ++i) { + if (modules_[i] == module) { + return i; + } + } + LOG(FATAL) << "unknown module"; +} + +Status HloModuleGroupMetadata::RecordInstructions() { + const auto visitor = [this](HloInstruction* hlo) -> Status { + if (hlo->opcode() == HloOpcode::kWhile) { + tracked_instructions_[hlo->while_condition()] = + TrackedInstruction(hlo, ComputationKind::kWhileCondition); + tracked_instructions_[hlo->while_body()] = + TrackedInstruction(hlo, ComputationKind::kWhileBody); + } else if (hlo->opcode() == HloOpcode::kConditional) { + tracked_instructions_[hlo->true_computation()] = + TrackedInstruction(hlo, ComputationKind::kConditionalTrue); + tracked_instructions_[hlo->false_computation()] = + TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } + if (!IsChannelInstruction(hlo)) { + return Status::OK(); + } + + // Add a new channel if needed. + if (channel_id_map_.find(hlo->channel_id()) == channel_id_map_.end()) { + channels_.emplace_back(); + channels_.back().id = hlo->channel_id(); + channel_id_map_[hlo->channel_id()] = channels_.size() - 1; + max_channel_id_ = std::max(max_channel_id_, hlo->channel_id()); + } + Channel& channel = channels_[channel_id_map_[hlo->channel_id()]]; + + if (hlo->opcode() == HloOpcode::kSend) { + TF_RET_CHECK(channel.send == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple send instructions"; + channel.send = hlo; + } + if (hlo->opcode() == HloOpcode::kRecv) { + TF_RET_CHECK(channel.recv == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple recv instructions"; + channel.recv = hlo; + } + if (hlo->opcode() == HloOpcode::kSendDone) { + TF_RET_CHECK(channel.send_done == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple send-done instructions"; + channel.send_done = hlo; + } + if (hlo->opcode() == HloOpcode::kRecvDone) { + TF_RET_CHECK(channel.recv_done == nullptr) + << "channel id " << hlo->channel_id() + << " is used by multiple recv-done instructions"; + channel.recv_done = hlo; + } + return Status::OK(); + }; + + for (HloModule* module : modules_) { + for (auto* computation : module->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(visitor)); + } + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, + HloInstruction* instruction2) { + TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || + instruction1->opcode() == HloOpcode::kConditional); + VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " + << instruction2->ToString(); + + if (!ContainsKey(companion_set_index_, instruction1) && + !ContainsKey(companion_set_index_, instruction2)) { + companion_sets_.push_back( + absl::make_unique>()); + auto companion_set = companion_sets_.back().get(); + companion_set->insert(instruction1); + companion_set->insert(instruction2); + companion_set_index_[instruction1] = companion_sets_.size() - 1; + companion_set_index_[instruction2] = companion_sets_.size() - 1; + } else if (!ContainsKey(companion_set_index_, instruction1)) { + companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_set_index_[instruction1] = companion_set_index_[instruction2]; + } else if (!ContainsKey(companion_set_index_, instruction2)) { + companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_set_index_[instruction2] = companion_set_index_[instruction1]; + } else if (companion_set_index_[instruction1] != + companion_set_index_[instruction2]) { + companion_sets_[companion_set_index_[instruction1]]->insert( + Companions(instruction2).begin(), Companions(instruction2).end()); + int64 index_to_remove = companion_set_index_[instruction2]; + for (HloInstruction* hlo : Companions(instruction2)) { + companion_set_index_[hlo] = companion_set_index_[instruction1]; + } + companion_sets_.erase(companion_sets_.begin() + index_to_remove); + } + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyChannelInstructions() { + for (const Channel& channel : channels_) { + if (channel.send == nullptr) { + return FailedPrecondition("missing send for id : %lld", channel.id); + } + if (channel.recv == nullptr) { + return FailedPrecondition("missing recv for id : %lld", channel.id); + } + if (channel.send_done == nullptr) { + return FailedPrecondition("missing send-done for id : %lld", channel.id); + } + if (channel.recv_done == nullptr) { + return FailedPrecondition("missing recv-done for id : %lld", channel.id); + } + } + + // Check if the shapes match for each channel. + for (const Channel& channel : channels_) { + const Shape& send_shape = channel.send->operand(0)->shape(); + const Shape& recv_shape = channel.recv_done->shape(); + if (!ShapeUtil::Compatible(send_shape, recv_shape)) { + return FailedPrecondition("send/recv shapes do not match"); + } + const HloModule* send_module = channel.send->parent()->parent(); + const HloModule* send_done_module = channel.send_done->parent()->parent(); + if (send_module != send_done_module) { + return FailedPrecondition( + "send and send-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + } + const HloModule* recv_module = channel.recv->parent()->parent(); + const HloModule* recv_done_module = channel.recv_done->parent()->parent(); + if (recv_module != recv_done_module) { + return FailedPrecondition( + "recv and recv-done (channel=%lld) must be on the same device: %lld " + "vs. %lld", + channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + } + if (send_module == recv_module) { + return FailedPrecondition( + "send and recv (channel=%lld) must be on different devices: %lld", + channel.id, GetModuleId(send_module)); + } + } + + // Check if channel instructions are used only in allowed computations. + const auto allowed = [this](HloInstruction* hlo) { + HloComputation* computation = hlo->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return true; + } + return false; + }; + for (const Channel& channel : channels_) { + if (!allowed(channel.send) || !allowed(channel.send_done) || + !allowed(channel.recv) || !allowed(channel.recv_done)) { + return FailedPrecondition("channel is used in disallowed computation"); + } + } + // Check if the nest levels match for each channel. + for (const Channel& channel : channels_) { + std::vector path = GetCompanionsPath(channel.send); + if (!CheckCompanionPathsCompatibility( + path, GetCompanionsPath(channel.send_done)) || + !CheckCompanionPathsCompatibility(path, + GetCompanionsPath(channel.recv)) || + !CheckCompanionPathsCompatibility( + path, GetCompanionsPath(channel.recv_done))) { + return FailedPrecondition( + "Nest companion paths do not match for channel %lld", channel.id); + } + } + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..c48a7ab0b59269474f7406ef24a249355528e085 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -0,0 +1,239 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class for bookkeeping the information on the given modules, in particular on +// the interaction between computations. +// +// Companion instructions are one of the information collected as we build the +// metadata. For example, for each While instruction, companion instructions +// refer to a set of While instructions in other computations that communicate +// with each other. +// In the example below with 3 modules, {While_0, While_2, While_5}, {While_1, +// While_4}, {While_3, While_6} are companion sets. +// +// +// While_0() { While_2() { While_5() { +// While_1() { Send(0) } While_3() { Send(1) } While_6() { Recv(1) } +// } While_4() { Recv(0) } +// } +// +// Companion instructions are used to detect cycles in the graph and also for +// global scheduling. +class HloModuleGroupMetadata { + public: + // The kind of companion computation a given instruction can be within. + enum class ComputationKind { + kInvalid, + kWhileCondition, + kWhileBody, + kConditionalTrue, + kConditionalFalse, + }; + + // Tracks the instruction mapped to a given computation, and the computation + // kind. + // For example, a body computation of a while instruction, will generate a + // TrackedInstruction with instruction being the while instruction, and + // kind being ComputationKind::kWhileBody. + class TrackedInstruction { + public: + TrackedInstruction() = default; + TrackedInstruction(HloInstruction* instruction, ComputationKind kind) + : instruction_(instruction), kind_(kind) {} + + bool operator==(const TrackedInstruction& rhs) const { + return instruction_->opcode() == rhs.instruction_->opcode() && + kind_ == rhs.kind_; + } + bool operator!=(const TrackedInstruction& rhs) const { + return !operator==(rhs); + } + + HloInstruction* instruction() const { return instruction_; } + + string ToString() const; + + private: + HloInstruction* instruction_ = nullptr; + ComputationKind kind_ = ComputationKind::kInvalid; + }; + + // Represents a channel and the 4 instructions that form the channel. + struct Channel { + int64 id = -1; + HloInstruction* send = nullptr; + HloInstruction* recv = nullptr; + HloInstruction* send_done = nullptr; + HloInstruction* recv_done = nullptr; + }; + + explicit HloModuleGroupMetadata(const std::vector& modules) + : modules_(modules) {} + + ~HloModuleGroupMetadata() = default; + + // Build and return the metadata for the given modules. + static StatusOr> Build( + const std::vector& modules); + + // Returns true if the instruction is one of the 4 channel instructions (Send, + // Recv, SendDone, RecvDone). + bool IsChannelInstruction(const HloInstruction* instruction) const; + + // Returns true if the instruction is a companion instruction. See the class + // comment above on companion instructions. + bool IsCompanionInstruction(HloInstruction* hlo) const; + + // Returns true if the instruction is either a channel instruction or a + // companion instruction. + bool InstructionCommunicates(HloInstruction* hlo) const; + + // Returns the Channel instance for the given channel id. + const Channel& GetChannel(int64 channel_id) const; + + // Returns the computation that contains the peer channel instructions for + // the given instruction. + // + // Precondition: IsChannelInstruction(instruction) is true. + HloComputation* PeerComputation(const HloInstruction* instruction) const; + + // Returns the path of the nested companion instructions, in terms of HLO + // instructions. The path goes from inner to outer companions. + // The returned path does not include the input hlo instruction, in case it + // is a companion instruction. + std::vector GetCompanionsPath( + const HloInstruction* hlo) const; + + // Checks whether two companion paths (as returned by the GetCompanionsPath() + // API) are compatible. The two paths are compatible if the sequence of + // opcodes, and the companion kinds, of the two paths matches. + bool CheckCompanionPathsCompatibility( + const std::vector& path0, + const std::vector& path1) const; + + // Returns the unique integer for each module. The returned id is the index of + // the module in the module vector. + int64 GetModuleId(const HloModule* module) const; + + // Returns the companion instructions for the given instruction. + // + // Precondition: IsCompanionWhile(instruction) is true. + const std::unordered_set& Companions( + HloInstruction* instruction) const { + CHECK_EQ(companion_set_index_.count(instruction), 1); + return companion_set(companion_set_index_.at(instruction)); + } + + // Returns the companion set at the given index. + const std::unordered_set& companion_set(int64 index) const { + CHECK_LT(index, companion_sets_.size()); + return *companion_sets_[index]; + } + + // Returns the companion set index of the given instruction. + int64 companion_set_index(HloInstruction* instruction) const { + return companion_set_index_.at(instruction); + } + + // Returns the list of all companion sets in the HLO module group. + const std::vector>>& + companion_sets() const { + return companion_sets_; + } + + // Returns all channels in the module group. + const std::vector& channels() const { return channels_; } + + // Returns the maximum channel id used in the module group. + int64 max_channel_id() const { return max_channel_id_; } + + private: + Status Build(); + + // Record all channel instructions and While instructions. + Status RecordInstructions(); + + // Verifies the given HloModules are well-formed and follow the specification, + // in particular with respect to using channel instructions. + // + // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). + // * The shape of channel instructions match. + // * The nest level of channel instructions match. + // * Channel instructions are used in allowed computations; i.e., in the + // entry computation of the module or condition/body of While computations. + // + // TODO(b/62064342): Currently, HloModuleGroupScheduler checks if there is a + // cycle in the graph, but it would be good to verify here. + Status VerifyChannelInstructions(); + + // Adds metadata that the given two instructions are companions. + Status AddCompanion(HloInstruction* instruction1, + HloInstruction* instruction2); + + // Retrieves a pointer to the stored TrackedInstruction associated with a + // tracked computation, or nullptr in case such computation is not tracked. + const TrackedInstruction* GetTrackedInstruction( + const HloComputation* computation) const { + auto it = tracked_instructions_.find(computation); + return it != tracked_instructions_.end() ? &it->second : nullptr; + } + + // List of all companion instructions sets in the module. + std::vector>> + companion_sets_; + + // Map from each companion while instruction to the index into companion_set_. + tensorflow::gtl::FlatMap companion_set_index_; + + // Map from computation to the instruction using it (a kWhile, kConditional). + tensorflow::gtl::FlatMap + tracked_instructions_; + + // All channels in the module. + std::vector channels_; + + // Map from channel ids to the index in channels_. + tensorflow::gtl::FlatMap channel_id_map_; + + // The maximum channel id used in the module group. + int64 max_channel_id_ = -1; + + // The modules that this metadata was built from. + const std::vector& modules_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..289c96b0a7b90c5f8a122cd3fc327a5762099106 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -0,0 +1,316 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group_util.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +std::vector HloModuleGroupUtil::GlobalPredecessors( + HloInstruction* instruction) { + std::vector predecessors; + + // Adds to the unique predecessors list and also add companion instructions + // if the given predecessor has those. + auto add_unique_predecessor = [&](HloInstruction* predecessor) { + if (std::find(predecessors.begin(), predecessors.end(), predecessor) != + predecessors.end()) { + return; + } + if (!metadata_.IsCompanionInstruction(predecessor)) { + predecessors.push_back(predecessor); + return; + } + for (HloInstruction* companion : metadata_.Companions(predecessor)) { + predecessors.push_back(companion); + } + }; + + // If the given instruction is a companion instruction, we need to find the + // predecessors of all of its companion instructions. + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(instruction)) { + for (HloInstruction* companion : metadata_.Companions(instruction)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(instruction); + } + + for (HloInstruction* hlo : instruction_group) { + for (HloInstruction* operand : hlo->operands()) { + add_unique_predecessor(operand); + } + for (HloInstruction* control_predecessor : hlo->control_predecessors()) { + add_unique_predecessor(control_predecessor); + } + } + if (instruction->opcode() == HloOpcode::kRecvDone) { + // Send is a remote predecessor of RecvDone. + HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + add_unique_predecessor(send); + } + if (instruction->opcode() == HloOpcode::kSend) { + // Recv is a remote predecessor of Send. + HloInstruction* recv_done = + metadata_.GetChannel(instruction->channel_id()).recv_done; + CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + CHECK_EQ(recv_done->operand_count(), 1); + HloInstruction* recv = recv_done->mutable_operand(0); + add_unique_predecessor(recv); + } + return predecessors; +} + +std::vector HloModuleGroupUtil::GlobalSuccessors( + HloInstruction* instruction) { + std::vector successors; + + // Adds to the unique successors list and also add companion instructions + // if the given successor has those. + auto add_unique_successor = [&](HloInstruction* successor) { + if (std::find(successors.begin(), successors.end(), successor) != + successors.end()) { + return; + } + if (!metadata_.IsCompanionInstruction(successor)) { + successors.push_back(successor); + return; + } + for (HloInstruction* companion : metadata_.Companions(successor)) { + successors.push_back(companion); + } + }; + + // If the given instruction is a companion instruction, we need to find the + // successors of all of its companion instructions. + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(instruction)) { + for (HloInstruction* companion : metadata_.Companions(instruction)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(instruction); + } + + for (HloInstruction* hlo : instruction_group) { + for (HloInstruction* user : hlo->users()) { + add_unique_successor(user); + } + for (HloInstruction* control_successor : hlo->control_successors()) { + add_unique_successor(control_successor); + } + } + if (instruction->opcode() == HloOpcode::kRecv) { + // Send is a remote successor of Recv. + const HloInstruction* recv_done = instruction->users().front(); + CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; + add_unique_successor(send); + } + if (instruction->opcode() == HloOpcode::kSend) { + // RecvDone is a remote successor of Send. + HloInstruction* recv_done = + metadata_.GetChannel(instruction->channel_id()).recv_done; + add_unique_successor(recv_done); + } + return successors; +} + +std::vector HloModuleGroupUtil::RootInstructions( + tensorflow::gtl::ArraySlice computations) { + std::vector roots; + for (HloComputation* computation : computations) { + for (HloInstruction* instruction : computation->instructions()) { + if (GlobalSuccessors(instruction).empty()) { + roots.push_back(instruction); + } + } + } + return roots; +} + +Status HloModuleGroupUtil::VisitTopologicalOrder( + VisitStates* visit_state, const VisitFunction& visit_function, + HloInstruction* root) { + // Stack of HLO instructions visited in DFS order. + std::stack stack; + stack.push(root); + + while (!stack.empty()) { + HloInstruction* hlo = stack.top(); + + // Find the instruction group of the currently visited instruction. The + // instruction group represents all companion instructions of the + // current instruction, and are considered to be a single entity for the + // purpose of the traversal (i.e., they must always be in the same visit + // state). + std::vector instruction_group; + if (metadata_.IsCompanionInstruction(hlo)) { + for (HloInstruction* companion : metadata_.Companions(hlo)) { + instruction_group.push_back(companion); + } + } else { + instruction_group.push_back(hlo); + } + + if ((*visit_state)[hlo] == VisitState::kVisited) { + // All instructions in the group must be in the same state. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisited); + } + stack.pop(); + continue; + } + + if ((*visit_state)[hlo] == VisitState::kVisiting) { + TF_RETURN_IF_ERROR(visit_function(hlo, instruction_group)); + + // Set the visit state of all instructions in the group to kVisited. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kVisiting); + (*visit_state)[instruction] = VisitState::kVisited; + } + stack.pop(); + continue; + } + + // Set the visit state of all instructions in the group to kVisiting. + for (HloInstruction* instruction : instruction_group) { + TF_RET_CHECK((*visit_state)[instruction] == VisitState::kNotVisited) + << instruction->ToString(); + (*visit_state)[instruction] = VisitState::kVisiting; + } + + // For each instruction in the group, visit its predecessors (operands, + // control predecessors and remote predecessors). + for (HloInstruction* instruction : instruction_group) { + for (HloInstruction* predecessor : GlobalPredecessors(instruction)) { + // Visiting a node that is already being visited implies that there is + // a cycle. Generate an error with the list of instructions in the + // cycle. + if ((*visit_state)[predecessor] == VisitState::kVisiting) { + string cyclic_instructions; + for (const auto& state : *visit_state) { + if (state.second == VisitState::kVisiting) { + tensorflow::strings::StrAppend(&cyclic_instructions, + state.first->ToString(), "\n"); + } + } + // TODO(b/64305524): Improve the error message to print out the + // instructions in a deterministic order that forms the cycle. + return FailedPrecondition( + "Cross-computation cycle detected via communicating nodes. The " + "cycle contains the node %s. The cycle is found among the " + "following nodes. Note that the order of the nodes is arbitrary " + "and that the list may include nodes that are not part of the " + "cycle.\n%s", + predecessor->ToString().c_str(), cyclic_instructions.c_str()); + } + stack.push(predecessor); + } + } + } + + return Status::OK(); +} + +Status HloModuleGroupUtil::VerifyComputations( + tensorflow::gtl::ArraySlice computations) { + auto visit_function = + [&](HloInstruction* instruction, + const std::vector& instruction_group) { + return Status::OK(); + }; + int64 instructions_count = 0; + VisitStates visit_states; + for (HloComputation* computation : computations) { + // Visit all instructions, and not just from the root instruction of the + // computation. This allows us to detect dead cycles (i.e., cycles that + // are not reachable from the root) or to enforce an order for the + // communication instructions that are not reachable from any roots. + for (HloInstruction* instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + VisitTopologicalOrder(&visit_states, visit_function, instruction)); + } + instructions_count += computation->instruction_count(); + } + + // Check if all instructions are visited and are in the visited state. + TF_RET_CHECK(visit_states.size() == instructions_count); + for (auto& state : visit_states) { + TF_RET_CHECK(state.second == VisitState::kVisited); + } + + return Status::OK(); +} + +StatusOr> +HloModuleGroupUtil::ComputeReachability( + tensorflow::gtl::ArraySlice computations) { + std::list post_order; + auto visit_function = + [&](HloInstruction* instruction, + const std::vector& instruction_group) { + post_order.insert(post_order.end(), instruction_group.begin(), + instruction_group.end()); + return Status::OK(); + }; + HloModuleGroupUtil::VisitStates visit_states; + for (HloInstruction* root : RootInstructions(computations)) { + TF_RETURN_IF_ERROR( + VisitTopologicalOrder(&visit_states, visit_function, root)); + } + auto reachability = absl::make_unique(post_order); + for (HloInstruction* hlo : post_order) { + reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); + } + return std::move(reachability); +} + +void HloModuleGroupUtil::UpdateReachabilityThroughInstruction( + HloInstruction* instruction, HloReachabilityMap* reachability_map) { + std::queue worklist; + worklist.push(instruction); + + while (!worklist.empty()) { + HloInstruction* item = worklist.front(); + worklist.pop(); + if (reachability_map->SetReachabilityToUnion(GlobalPredecessors(item), + item)) { + for (HloInstruction* successor : GlobalSuccessors(item)) { + worklist.push(successor); + } + } + } +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h new file mode 100644 index 0000000000000000000000000000000000000000..c25ca1aff50b288f3ac3885cbed53e7ba9768430 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +// Collection of utilities for handling HloModuleGroups. +class HloModuleGroupUtil { + public: + explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata) + : metadata_(metadata) {} + + // Returns all unique predecessors of the instruction. This includes: + // * predecessors in the same computation: operands and control predecessors + // * Recv is a predecessor of Send + // * Send is a predecessor of RecvDone + // * predecessors of companions (if the instruction is a companion while) + // * predecessors' companions (for any predecessor that is a companion while) + std::vector GlobalPredecessors(HloInstruction* instruction); + + // Returns all unique successors of the instruction. This includes: + // * successors in the same computation: users and control successors + // * Send is a successor of Recv + // * RecvDone is a predecessor of Send + // * successors of companions (if the instruction is a companion while) + // * successors' companions (for any successor that is a companion while) + std::vector GlobalSuccessors(HloInstruction* instruction); + + // Returns the root instructions of the computations. + std::vector RootInstructions( + tensorflow::gtl::ArraySlice computations); + + // Visit state of each instruction during DFS traversal. + enum VisitState { + kNotVisited = 0, + kVisiting, + kVisited, + }; + + // Function called on each instruction group during the DFS traversal. See the + // comment for VisitTopologicalOrder()). + using VisitFunction = std::function& instruction_group)>; + + // Given the hlo instruction as the root, recursively visits all its + // predecessor instructions in DFS order to visit nodes in topological order. + // + // Note that the DFS traversal does not only visit nodes in the same + // computation (parent of the root instruction), but also visits nodes in + // different computations connected via communication instructions. During the + // traversal, companion While instructions (see the class comment in + // HloModuleGroupMetadata) are treated as a single instruction (called + // instruction group, which contains only a single instruction if the visiting + // node is not a companion while) -- visiting one of the instructions in the + // group effectively visits all other instructions in the group, and then all + // predecessor instructions of the group are visited. + // + // * visit_state: map from each instruction to its visit state. + // * visit_function: function called when each instruction group. + // * root: the root instruction of the traversal. + using VisitStates = tensorflow::gtl::FlatMap; + Status VisitTopologicalOrder(VisitStates* visit_state, + const VisitFunction& visit_function, + HloInstruction* root); + + // Verifies that the computations are well-formed (e.g., no cycles). + Status VerifyComputations( + tensorflow::gtl::ArraySlice computations); + + // Below Reachability utils resemble those in HloComputation, except that + // they can handle instructions across multiple computations. + // + // Creates the reachability map for the instructions in the computations. + StatusOr> ComputeReachability( + tensorflow::gtl::ArraySlice computations); + + // Updates the reachability of the given instruction, taking the global + // predeccessorss and successors into account. + void UpdateReachabilityThroughInstruction( + HloInstruction* instruction, HloReachabilityMap* reachability_map); + + private: + const HloModuleGroupMetadata& metadata_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 3d64523a79fc50638fdf378b5d521a5cd4482b90..dddc72480f93c4c3cc29f41db99fa773dc8d6b68 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -54,6 +54,7 @@ namespace xla { V(kBitcast, "bitcast") \ V(kBitcastConvert, "bitcast-convert") \ V(kBroadcast, "broadcast") \ + V(kBroadcastDimOne, "broadcast-dim-one") \ V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ @@ -76,9 +77,11 @@ namespace xla { V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIsFinite, "is-finite") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 68e3c9618c1fe9daacb0aee3ee98862c8b9e4bc4..e89d94bede6c437ca1131a1b1b0098390d58c0d9 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -66,6 +66,28 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } } + // If the common ancestor is a conditional instruction, even though the true + // and false computations are not really ordered per-se, we define the true + // computation to be ordered before the false one. + // This ensures that buffers can still be shared among the two computations + // as they will forcibly have disjoint liveness. + if (a_ancestor == b_ancestor && + a_ancestor->opcode() == HloOpcode::kConditional) { + const HloComputation* true_computation = a_ancestor->true_computation(); + const HloComputation* false_computation = a_ancestor->false_computation(); + if (call_graph_->InstructionIsNestedIn(a, true_computation) && + call_graph_->InstructionIsNestedIn(b, false_computation)) { + return true; + } + // If 'b' is the conditional ancestor, and 'a' is within the true or false + // computations, 'a' executes before 'b'. + if (b == a_ancestor && + (call_graph_->InstructionIsNestedIn(a, true_computation) || + call_graph_->InstructionIsNestedIn(a, false_computation))) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } @@ -118,7 +140,18 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { b.defining_instruction()->while_condition()))) { return true; } - + // If 'b' is a conditional phi and 'a' is in the true or false computation, + // then 'a' executes before 'b'. + if (b.is_phi() && + b.defining_instruction()->opcode() == HloOpcode::kConditional && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->true_computation()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->false_computation()))) { + return true; + } return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); } @@ -186,6 +219,22 @@ bool HloOrdering::UseIsBeforeValueDefinition( } } + if (use.instruction->opcode() == HloOpcode::kConditional) { + const HloInstruction* conditional = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + conditional->true_computation())) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in TRUE computation"; + return true; + } + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + conditional->false_computation())) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in FALSE computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } @@ -196,18 +245,17 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() << ", b = " << b.ToShortString() << ")"; if (!IsDefinedBefore(a, b)) { - VLOG(4) << "a not defined before b"; + VLOG(4) << a << " not defined before " << b; return false; } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { - VLOG(4) << "use of a (" << use << ") not before b is defined"; + VLOG(4) << "use of " << a << " (" << use << ") not before " << b + << " is defined"; return false; } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index aba66114de649ce7667ae77174e9c4073b010b90..37a7fbad97cea2f34798efecc2489e57d1374f35 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -34,53 +34,6 @@ namespace { class HloOrderingTest : public HloTestBase {}; -TEST_F(HloOrderingTest, LastUseScheduledFirst) { - // Tests scheduling of the following HLO code: - // - // %ab = abs(%param) - // %exp = exp(%param) - // %add = add(%ab, %exp) - // %negate = negate(%exp) - // %sub = subtract(%add, %negate) - // - // %add should be scheduled before %negate because %add is the last (and only) - // use of %ab. Scheduling %add first then frees up %ab's buffer. - const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); - auto builder = HloComputation::Builder(TestName()); - auto param = - builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); - auto ab = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); - auto sub = builder.AddInstruction( - HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - - // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - - SequentialHloOrdering ordering(module.get(), sequence); - EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); -} - TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // Tests the ordering of instructions in different computations using the // following HLO code: @@ -262,8 +215,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { scalar_shape, HloOpcode::kAdd, constant, xla_while)); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); // Init value is defined before the while, but live range is not before the @@ -362,5 +315,66 @@ ENTRY while.v11 { ordering.ToString(); // Shouldn't crash. } +TEST_F(HloOrderingTest, ConditionalInstructionOrdering) { + const char* module_str = R"( +HloModule test_conditional_module + +true_branch { + param.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(param.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(param.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + ROOT tuple.1 = (s32[], s32[]) tuple(add.1, get-tuple-element.1) +} + +false_branch { + param.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(param.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(param.2), index=1 + add.2 = s32[] add(get-tuple-element.3, get-tuple-element.4) + ROOT tuple.2 = (s32[], s32[]) tuple(add.2, get-tuple-element.4) +} + +ENTRY root { + param.3 = (pred[], (s32[], s32[])) parameter(0) + pred.1 = pred[] get-tuple-element(param.3), index=0 + cond_arg.1 = (s32[], s32[]) get-tuple-element(param.3), index=1 + conditional = (s32[], s32[]) conditional(pred.1, cond_arg.1, cond_arg.1), true_computation=true_branch, false_computation=false_branch + cond_res.1 = s32[] get-tuple-element(conditional), index=0 + cond_res.2 = s32[] get-tuple-element(conditional), index=1 + add.3 = s32[] add(cond_res.1, cond_res.2) + ROOT result = (s32[], s32[], s32[]) tuple(add.3, cond_res.1, cond_res.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + + // Even though the true and false branches has no ordering, since they do not + // interfere (as they are mutually exclusive), we define the true computation + // to be before the false one. + // Similarly, any instruction in the true or false branches are considered + // before the conditional instruction. The roots are effectively "at the same + // time" WRT the conditional, but they are Phi-ed anyway. + HloInstruction* add_1 = FindInstruction(module.get(), "add.1"); + HloInstruction* add_2 = FindInstruction(module.get(), "add.2"); + HloInstruction* add_3 = FindInstruction(module.get(), "add.3"); + HloInstruction* conditional = FindInstruction(module.get(), "conditional"); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_2))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE( + ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(conditional))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_1), + dataflow->GetValueDefinedAt(add_3))); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(add_2), + dataflow->GetValueDefinedAt(add_3))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 78e6a101c10a1e812e3e2631d520139fd0bc425c..3460679558d185d1e022660d9a1d23176d0d96bf 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include + +#include "tensorflow/compiler/xla/util.h" + namespace xla { HloProto MakeHloProto(const HloModule& module, @@ -35,4 +39,35 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr> EntryComputationParameterShapes( + const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + + std::vector parameter_shapes; + const auto& program_shape = hlo_proto.hlo_module().program_shape(); + for (const Shape& shape : program_shape.parameters()) { + parameter_shapes.push_back(&shape); + } + return parameter_shapes; +} + +StatusOr EntryComputationOutputShape(const HloProto& hlo_proto) { + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + if (!hlo_proto.hlo_module().program_shape().has_result()) { + return NotFound("HloProto missing result in its program shape"); + } + + return &hlo_proto.hlo_module().program_shape().result(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 320288fdb9aa0810b306b1d78bd1ff4cfc366ed2..3d9c375cd5d26f92cf8316f78789daf4fc08c927 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,15 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Returns the shapes of the parameters of the entry computation. Shape pointers +// refer to shapes inside of the given HloProto. +StatusOr> EntryComputationParameterShapes( + const HloProto& hlo_proto); + +// Returns the shape of the output of the entry computation. The shape pointer +// refers to the output shape inside of the given HloProto. +StatusOr EntryComputationOutputShape(const HloProto& hlo_proto); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b9cca138703c8fa61aadf69dd7304a215a9f4be2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace { + +class HloProtoUtilTest : public ::testing::Test {}; + +TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) { + HloProto hlo_proto; + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing HloModuleProto")); +} + +TEST_F(HloProtoUtilTest, MissingProgramShape) { + HloProto hlo_proto; + HloModuleProto* module = hlo_proto.mutable_hlo_module(); + module->set_name("entry"); + + auto status = EntryComputationParameterShapes(hlo_proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + ::testing::HasSubstr("missing program shape")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c6b4dc0368d92fd477decdfb38045f74f8696803..b0632448933df4b7681a0704c58d697b5ec68a1f 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -60,6 +60,7 @@ bool IsRematerializable(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: + case HloOpcode::kConditional: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kParameter: @@ -1319,7 +1320,7 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SchedulerAlgorithm scheduler_algorithm, + MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { HloRematerialization remat(scheduler_algorithm, size_function); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 52553439033a3bcfa4b472f13f9cd4b1ecf5ed96..2ee2dd0571ae8c6604e4ca722351fd48a913bda5 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -66,12 +66,12 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, const ShapeSizeFunction& size_function) : scheduler_algorithm_(scheduler_algorithm), size_function_(size_function) {} @@ -108,7 +108,7 @@ class HloRematerialization { const HloInstruction* instruction) const; // Selects an algorithm to use for HLO scheduling. - SchedulerAlgorithm scheduler_algorithm_; + MemorySchedulerAlgorithm scheduler_algorithm_; // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 1b7d26dde501a6a0955d62ea0938e0683a32d49d..83de54f3fa56ee660b79d8c366dbc0b52f9fde87 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -162,7 +162,7 @@ TEST_F(HloRematerializationTest, SingleComputation) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/14 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -195,7 +195,7 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/20 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -236,7 +236,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/17 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -272,7 +272,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/15 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -314,7 +314,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/13 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. @@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { bool changed, HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), SchedulerAlgorithm::kAuto, &sequence)); + module.get(), DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -480,7 +480,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -577,7 +577,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { HloRematerialization::RematerializeAndSchedule( ByteSizeOf, /*memory_limit_bytes=*/22 * 1024, module.get(), - SchedulerAlgorithm::kAuto, &sequence)); + DefaultMemoryScheduler, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 41b079eb799d06321a31f7d7ae0630dc8d58c46b..2e834a79d9f63154172798d252be938d0d475c01 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -16,21 +16,16 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_runner.h" -#include #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -52,10 +47,9 @@ namespace { // Creates an HloModule from the given proto. StatusOr> HloProtoToModule( const HloProto& proto, const DebugOptions& debug_options) { - TF_ASSIGN_OR_RETURN( - HloModuleConfig config, - HloModule::CreateModuleConfigFromProto(proto.hlo_module())); - config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN(HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto.hlo_module(), + debug_options)); TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(proto.hlo_module(), config)); return std::move(module); @@ -92,15 +86,6 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, return tools::Parse(hlo_string, config); } -// Define this in .cc file to avoid having to include eigen or forward declare -// these types in the header. -struct HloRunner::EigenThreadPoolWrapper { - std::unique_ptr pool; - std::unique_ptr device; -}; - -HloRunner::HloRunner() {} - HloRunner::HloRunner(se::Platform* platform) { BackendOptions backend_options; backend_options.set_platform(platform); @@ -110,36 +95,18 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::ExecuteInternal( +StatusOr> HloRunner::Execute( std::unique_ptr module, const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes) { - if (run_hlo_passes) { - TF_ASSIGN_OR_RETURN( - module, backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr)); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor(), - /*device_allocator=*/nullptr)); - + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); se::Stream stream(backend().default_stream_executor()); stream.Init(); - ExecutableRunOptions run_options; - run_options.set_device_ordinal(backend().default_device_ordinal()); - run_options.set_stream(&stream); - run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - backend().eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_run_options( - run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); + ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice( + backend().default_device_ordinal(), &stream, nullptr)); + const ExecutableRunOptions& run_options = service_run_options.run_options(); // Copy arguments to device. std::vector> argument_buffers; @@ -158,8 +125,8 @@ StatusOr> HloRunner::ExecuteInternal( TF_ASSIGN_OR_RETURN( std::unique_ptr result, - executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, - /*hlo_execution_profile=*/nullptr)); + executable->ExecuteOnStreamWrapper( + &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); // Create a ScopedShapedBuffer of the result to manage deallocation. This will // deallocate all the device memory when it goes out of scope. @@ -179,10 +146,153 @@ StatusOr> HloRunner::ExecuteInternal( return result_literal; } +StatusOr>> HloRunner::ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + CreateExecutable(std::move(module), options.run_hlo_passes)); + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + std::vector> streams; + std::vector service_run_options; + std::vector> argument_buffers; + // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are + // no arguments. + std::vector argument_buffer_ptrs( + options.num_replicas * options.arguments.size() + 1); + std::vector> + argument_buffer_slices; + int64 index = 0; + for (int64 i = 0; i < options.num_replicas; ++i) { + int64 device = device_assignment(i, 0); + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device)); + streams.push_back(absl::make_unique(executor)); + streams.back()->Init(); + service_run_options.emplace_back(GetServiceRunOptionsForDevice( + device, streams.back().get(), &device_assignment)); + + // Copy arguments to device. + for (const Literal* argument : options.arguments) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr argument_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + argument->shape(), backend().memory_allocator(), device)); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + executor, *argument, *argument_buffer)); + argument_buffers.push_back(std::move(argument_buffer)); + argument_buffer_ptrs[index++] = argument_buffers.back().get(); + } + argument_buffer_slices.emplace_back( + &argument_buffer_ptrs[index - options.arguments.size()], + options.arguments.size()); + } + + std::unique_ptr pool; + int64 num_threads = (options.infeed != nullptr) ? options.num_replicas : 0; + if (ShapeUtil::IsInitialized(options.outfeed_shape)) { + num_threads += options.num_replicas; + } + if (num_threads > 0) { + pool = absl::make_unique( + tensorflow::Env::Default(), "infeed_outfeed", + /*num_threads=*/num_threads); + } + if (options.infeed != nullptr) { + for (int64 i = 0; i < options.num_replicas; ++i) { + int64 device = device_assignment(i, 0); + pool->Schedule([this, device, &options]() { + se::StreamExecutor* executor = + backend().stream_executor(device).ValueOrDie(); + VLOG(1) << "Starting infeed on device " << device; + for (int64 step = 1; + options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { + TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToInfeed( + executor, *options.infeed)); + if (step % 100 == 0) { + VLOG(1) << "Infeed step " << step; + } + } + }); + } + } + if (ShapeUtil::IsInitialized(options.outfeed_shape)) { + for (int64 i = 0; i < options.num_replicas; ++i) { + int64 device = device_assignment(i, 0); + pool->Schedule([this, device, &options]() { + se::StreamExecutor* executor = + backend().stream_executor(device).ValueOrDie(); + VLOG(1) << "Starting outfeed on device " << device; + for (int64 step = 1; + options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { + auto literal = absl::make_unique(); + TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( + executor, options.outfeed_shape, literal.get())); + if (options.outfeed_values != nullptr) { + options.outfeed_values->push_back(std::move(literal)); + } + if (step % 100 == 0) { + VLOG(1) << "Outfeed step " << step; + } + } + }); + } + } + + LOG(INFO) << "Replicated execution started"; + TF_ASSIGN_OR_RETURN(std::vector> results, + executable->ExecuteOnStreams(service_run_options, + argument_buffer_slices)); + LOG(INFO) << "Replicated execution terminated"; + + std::vector> exec_results; + for (int64 i = 0; i < options.num_replicas; ++i) { + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + ScopedShapedBuffer::MakeScoped( + results[i].get(), backend().memory_allocator())); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + backend().transfer_manager()->TransferLiteralFromDevice( + streams[i]->parent(), *result)); + exec_results.push_back(std::move(literal)); + } + return std::move(exec_results); +} + +StatusOr> HloRunner::CreateExecutable( + std::unique_ptr module, bool run_hlo_passes) { + if (run_hlo_passes) { + TF_ASSIGN_OR_RETURN( + module, backend().compiler()->RunHloPasses( + std::move(module), backend().default_stream_executor(), + backend().memory_allocator())); + } + return backend().compiler()->RunBackend(std::move(module), + backend().default_stream_executor(), + backend().memory_allocator()); +} + +ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( + int64 device, se::Stream* stream, DeviceAssignment* device_assignment) { + ExecutableRunOptions run_options; + run_options.set_device_ordinal(device); + run_options.set_stream(stream); + run_options.set_allocator(backend().memory_allocator()); + run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + backend().eigen_intra_op_thread_pool_device()); + if (device_assignment != nullptr) { + run_options.set_device_assignment(device_assignment); + } + return ServiceExecutableRunOptions(run_options, backend().StreamBorrower(), + backend().inter_op_thread_pool()); +} + Backend& HloRunner::backend() { if (!backend_) { backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); - VLOG(1) << "executing on platform " << backend().platform()->Name(); + VLOG(1) << "Executing on platform " << backend().platform()->Name(); } return *backend_; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index cbaebc68bee708090b8ccb2eae19b556c4d6d453..f54fb44766eb07f402b2946abc83d50d155e47c1 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -16,17 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ +#include #include +#include #include #include #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" +#include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,9 +44,43 @@ namespace xla { // file), or parsed from a hlo textual IR string. class HloRunner { public: - HloRunner(); - - HloRunner(::perftools::gputools::Platform* platform); + // The options used to configure a ExecuteReplicated() call. + struct ReplicatedExecuteOptions { + // The number of devices the HLO module should be replicated onto. + int64 num_replicas = 1; + + // The arguments to be fed to each replica. Since this is used for a + // replicated execution, all the arguments are the same for all replicas. + std::vector arguments; + + // If the HLO module being run has an infeed instruction, this will be the + // data which will be fed to it, for as many as infeed_steps steps. + const Literal* infeed = nullptr; + + // The number of times the infeed literal should be fed to the HLO module. + // For a clean exit, this should match the iterations-per-loop parameter + // used when generating the HLO module proto (that is usually the main + // while bounary counter). A value higher then iterations-per-loop would + // lead to infeed threads feeding to a gone computation, while a lower + // value would trigger a stuck ExecuteReplicated() call (the computation + // will be trying to infeed data which will never come). + int64 infeed_steps = -1; + + // The shape of the outfeed operation. If empty, the HLO module does not + // generate any outfeed. + Shape outfeed_shape; + + // A pointer to a vector where the outfeed values will be stored. If + // nullptr, the values will be read and discarded. + std::vector>* outfeed_values = nullptr; + + // Whether the HLO passes should be run on the input module. Usually + // saved modules are coming from after the HLO pass pipeline, so triggering + // another run will likely cause errors. + bool run_hlo_passes = false; + }; + + explicit HloRunner(::perftools::gputools::Platform* platform); ~HloRunner(); @@ -64,17 +103,34 @@ class HloRunner { const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the - // result as a Literal. The LiteralPtr type accepts Literal* or - // std::unique_ptr. + // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes = true); + StatusOr> Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes = true) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + c_transform( + arguments, std::back_inserter(argument_pointers), + [](const std::unique_ptr& literal) { return literal.get(); }); + return Execute(std::move(module), argument_pointers, run_hlo_passes); + } + + // Executes a given HLO module into a set of replicas, and returns a map + // with the replica number as key, and the corresponding returned literal as + // value. + StatusOr>> ExecuteReplicated( + std::unique_ptr module, + const ReplicatedExecuteOptions& options); + // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. // @@ -83,31 +139,22 @@ class HloRunner { Backend& backend(); private: - StatusOr> ExecuteInternal( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); - - struct EigenThreadPoolWrapper; - - std::unique_ptr thread_pool_wrapper_; + // Creates an executable object given an HLO module. If run_hlo_passes is + // true, the HLO passes will be run before. + StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes); + + // Creates a ServiceExecutableRunOptions object to configure a run on device, + // using the provided stream object. If device_assignment is not nullptr, it + // will be used to configure the replication parameters. Replicated executions + // should pass the device_assignment parameter. + ServiceExecutableRunOptions GetServiceRunOptionsForDevice( + int64 device, ::perftools::gputools::Stream* stream, + DeviceAssignment* device_assignment); std::unique_ptr backend_; }; -template -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - for (const auto& argument : arguments) { - argument_pointers.push_back(&*argument); - } - return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8dc4d4f7bac1b2007f2b9f60d126fa07e314dac9..1a767628f6e2d33df353366974fb866e89f0df5a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include +#include #include #include @@ -103,10 +103,11 @@ class ListScheduler { for (auto* instruction : computation.instructions()) { tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - instr_uses.insert(buffer); - } + points_to_analysis.GetPointsToSet(operand).ForEachElement( + [&](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + instr_uses.insert(buffers.begin(), buffers.end()); + }); } buffer_uses_[instruction] = std::vector( instr_uses.begin(), instr_uses.end()); @@ -151,8 +152,10 @@ class ListScheduler { int64 bytes_defined; // For each buffer B used by this instruction, we keep a pair (B, U), where - // U is the number of uses of B that have not yet been scheduled. - std::vector> + // U is the number of uses of B that have not yet been scheduled. This pair + // is a pointer into the unscheduled_use_count_ map, so it gets updated for + // free when we update counts in the map. + std::vector*> used_buffer_unscheduled_use_counts; }; @@ -175,8 +178,8 @@ class ListScheduler { } auto unscheduled_use_count_it = unscheduled_use_count_.find(buffer); CHECK(unscheduled_use_count_it != unscheduled_use_count_.end()); - entry.used_buffer_unscheduled_use_counts.emplace_back( - unscheduled_use_count_it->first, unscheduled_use_count_it->second); + entry.used_buffer_unscheduled_use_counts.push_back( + &*unscheduled_use_count_it); } return entry; } @@ -185,8 +188,8 @@ class ListScheduler { int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { - auto buffer = kv.first; - auto use_count = kv.second; + auto buffer = kv->first; + auto use_count = kv->second; if (use_count == 1) { freed_bytes += size_function_(*buffer); } @@ -217,23 +220,18 @@ class ListScheduler { } } - auto priority_comparator = - [this](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }; - std::priority_queue, - std::vector>, - decltype(priority_comparator)> - ready_queue(priority_comparator); + // Use a multimap to sort ReadyListEntry according to their priority. + std::multimap ready_queue; - // Set of instructions in the ready list. - tensorflow::gtl::FlatSet ready_instructions; + // Map of ready instructions to their iterators in ready_queue. + tensorflow::gtl::FlatMap::iterator> + ready_instructions; auto add_to_ready_queue = [&](HloInstruction* inst) { auto entry = MakeReadyListEntry(inst); - ready_queue.emplace(GetPriority(entry), std::move(entry)); - ready_instructions.insert(inst); + auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); + ready_instructions[inst] = it; }; for (auto* instruction : computation_.instructions()) { @@ -247,14 +245,10 @@ class ListScheduler { while (!ready_queue.empty()) { // Remove the selected instruction from the ready list and add it to the // schedule. - const HloInstruction* best = ready_queue.top().second.instruction; - ready_queue.pop(); - // We may have duplicates in the priority queue, because when a ready - // instruction's priority goes up, we reinsert it to the priority queue. - // Skip the duplicate. - if (scheduled_instructions_.find(best) != scheduled_instructions_.end()) { - continue; - } + auto best_it = ready_queue.end(); + --best_it; + const HloInstruction* best = best_it->second.instruction; + ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); scheduled_instructions_.insert(best); @@ -287,16 +281,27 @@ class ListScheduler { update_pred_count(succ); } // The unscheduled use count for a buffer has changed to 1, so the - // priorities of some ready instructions may go up. We reinsert them to - // the priority queue, so that they can appear earlier. The old entries - // will become duplicates and will be skipped. + // priorities of some ready instructions may go up. We update them in the + // ready queue, so that they can appear earlier. if (adjust_ready_queue) { for (HloInstruction* operand : best->operands()) { for (HloInstruction* operand_user : operand->users()) { - if (ready_instructions.find(operand_user) != - ready_instructions.end()) { - add_to_ready_queue(operand_user); + auto ready_instructions_it = ready_instructions.find(operand_user); + if (ready_instructions_it == ready_instructions.end()) { + continue; } + auto ready_queue_it = ready_instructions_it->second; + auto& entry = ready_queue_it->second; + Priority new_priority = GetPriority(entry); + if (new_priority == ready_queue_it->first) { + continue; + } + // Create a new entry in ready_queue, then update + // ready_instructions[operand_user] to refer to the new entry. + ready_instructions_it->second = + ready_queue.emplace(new_priority, std::move(entry)); + // Remove the old entry in ready_queue. + ready_queue.erase(ready_queue_it); } } } @@ -317,8 +322,9 @@ class ListScheduler { buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map. - tensorflow::gtl::FlatMap unscheduled_use_count_; + // LogicalBuffer. We rely on iterator stability in this map, and that the map + // entries are std::pair's. + std::unordered_map unscheduled_use_count_; // Set of instructions which have been scheduled. tensorflow::gtl::FlatSet scheduled_instructions_; @@ -334,7 +340,33 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> RunDFSMemoryScheduler( +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function); +} + +} // namespace + +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { @@ -343,6 +375,7 @@ StatusOr> RunDFSMemoryScheduler( // simply users-1 for each instruction. By subtracting 1, we're saying that // instructions with no users or a single user don't count; instructions with // lots of fan-out will be visited earlier. + int64 cumulative_total_size = 0; tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -352,14 +385,17 @@ StatusOr> RunDFSMemoryScheduler( continue; } extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; - total_sizes[hlo] = SumLogicalBufferSizes( + int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + total_sizes[hlo] = logical_buffer_size; + cumulative_total_size += logical_buffer_size; tensorflow::gtl::FlatSet unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -387,32 +423,17 @@ StatusOr> RunDFSMemoryScheduler( return sequence; } -StatusOr MinimumMemoryForComputation( +StatusOr> ListMemoryScheduler( const HloComputation& computation, - const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; + return ListScheduler::Run(computation, points_to_analysis, size_function); } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm == SchedulerAlgorithm::kListSchedule) { - return ListScheduler::Run(computation, points_to_analysis, size_function); - } - if (algorithm == SchedulerAlgorithm::kDfsSchedule) { - return RunDFSMemoryScheduler(computation, points_to_analysis, - size_function); - } - + const LogicalBuffer::SizeFunction& size_function) { // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -422,7 +443,7 @@ StatusOr> CreateMemoryMinimizingSequence( // within the caller's context. But it's good enough for now. TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, @@ -431,7 +452,7 @@ StatusOr> CreateMemoryMinimizingSequence( TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + DFSMemoryScheduler(computation, points_to_analysis, size_function)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, @@ -449,12 +470,10 @@ StatusOr> CreateMemoryMinimizingSequence( } } -} // namespace - StatusOr CreateMemoryMinimizingSequence(const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -470,7 +489,7 @@ CreateMemoryMinimizingSequence(const HloModule& module, StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm) { + const MemorySchedulerAlgorithm& algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 1d1eb1e064f75c2220b39e84b010e720a0c37880..068e68383deb170ded1c9b09a8b7ceb8c4c0ab4b 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -33,28 +34,48 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); -enum class SchedulerAlgorithm { - kListSchedule, - kDfsSchedule, +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory, given a points-to analysis result +// that describes buffer aliasing, together with a target-specific size function +// that maps a tensor's logical size to its padded size. +typedef std::function>( + const HloComputation&, const TuplePointsToAnalysis&, + const LogicalBuffer::SizeFunction&)> + MemorySchedulerAlgorithm; - // Selects the available scheduler algorithm that had the minimum memory in - // the resulting sequence (a la MinimumMemoryForSequence). - kAuto, -}; +// List scheduler +StatusOr> ListMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// DFS-order scheduler +StatusOr> DFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + +// The default scheduling algorithm. Runs both the list scheduler +// and the DFS scheduler, and chooses whichever returns a lower min-memory, +// not accounting for fragmentation. +StatusOr> DefaultMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function, - SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); + const MemorySchedulerAlgorithm& algorithm = {}); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 7fb338e7042ce19ac9647e23719e738f3ef42c7c..74544c4a67a819d341056aba4cf6b321a5a86c0a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -89,5 +90,105 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); } +class HloSchedulingTest : public HloTestBase {}; + +TEST_F(HloSchedulingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. + const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); +} + +TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { + const char* module_str = R"( +HloModule test_aliasing_module + +ENTRY root { + param = s32[1000] parameter(0) + p0 = s32[1000] copy(param) + p1 = s32[1000] copy(param) + t = (s32[1000], s32[1000]) tuple(p0, p1) + a = s32[1000] get-tuple-element(t), index=0 + b = s32[1000] get-tuple-element(t), index=1 + c = s32[1000] add(a, b) + d = s32[1000] add(c, b) + e = s32[1000] add(c, c) + f = s32[1000] add(e, e) + ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + + std::unordered_map instructions_by_name; + for (const HloInstruction* instruction : + sequence.at(module->entry_computation())) { + instructions_by_name[instruction->name()] = instruction; + } + + // The first instruction should be the parameter and the last the root. + EXPECT_EQ(instructions_by_name.at("param"), + sequence.at(module->entry_computation()).front()); + EXPECT_EQ(instructions_by_name.at("result"), + sequence.at(module->entry_computation()).back()); + + // Instructions "d" and "e" will both be schedulable at the same time, but + // instruction "d" allows us to free the buffer of "p1", so the list scheduler + // should prefer it. + SequentialHloOrdering ordering(module.get(), sequence); + EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), + instructions_by_name.at("e"))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 447c2446668253c932b44b51b2db22bfd47f9957..1b42349b0b3ad9634bb910b3843affed6a0ca334 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -20,6 +20,7 @@ limitations under the License. namespace xla { +using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrCat; HloSharding HloSharding::AssignDevice(int64 device_id) { @@ -57,8 +58,9 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", - "devices=", VectorString(tile_assignment_), "}"); + return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", "devices=[", + Join(tile_assignment_.dimensions(), ","), "]", + Join(tile_assignment_, ","), "}"); } } @@ -183,6 +185,10 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { // shape tree. ShapeTree shape_tree = GetAsShapeTree(shape); for (const auto& index_to_sharding : shape_tree.leaves()) { + if (index_to_sharding.first.empty()) { + // An empty tuple has a ShapeTree with a single leaf at the empty index. + continue; + } Status status = index_to_sharding.second.ValidateNonTuple( ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); if (!status.ok()) { @@ -222,7 +228,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, uint32 core) { + [&](tensorflow::gtl::ArraySlice indices, int32 core) { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { @@ -344,4 +350,45 @@ OpSharding HloSharding::ToProto() const { return result; } +HloSharding HloSharding::TransformShardedTileShape( + const Shape& new_shape, + const std::function& transform) const { + CHECK(!IsTuple()); + if (IsTileMaximal()) { + return *this; + } + CHECK_EQ(ShapeUtil::Rank(new_shape), ShapeUtil::Rank(tile_shape())); + Shape new_tile_shape; + new_tile_shape.set_element_type(tile_shape().element_type()); + for (int64 i = 0; i < ShapeUtil::Rank(new_shape); ++i) { + int64 dim; + if (tile_assignment().dim(i) == 1) { + dim = new_shape.dimensions(i); + } else if (transform) { + dim = transform(i, tile_shape().dimensions(i)); + } else { + dim = tile_shape().dimensions(i); + } + new_tile_shape.add_dimensions(dim); + } + TF_CHECK_OK( + LayoutUtil::CopyLayoutBetweenShapes(tile_shape_, &new_tile_shape)); + return HloSharding::Tile(new_tile_shape, tile_assignment()); +} + +HloSharding HloSharding::GetSubSharding(const Shape& shape, + const ShapeIndex& index) const { + CHECK(IsTuple()); + + ShapeTree sub_shape_tree(ShapeUtil::GetSubshape(shape, index), + Replicate()); + sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); + return Tuple(sub_shape_tree); +} + +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { + out << sharding.ToString(); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 7263198385cf0c84b1dac1e15177dcac99adaafb..2b8e757f42991f697df37d3d34bfdff6a36bc509 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -94,6 +94,10 @@ class HloSharding { // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); + // Checks whether device is a reserved device number. A reserved device number + // has usually a special meaning, with dedicated handling logic. + static bool IsReservedDevice(int64 device) { return device < 0; } + OpSharding ToProto() const; string ToString() const; @@ -171,9 +175,13 @@ class HloSharding { } } + // Retrieves the sub sharding at a given index, out of a tuple sharding. + // REQUIRES: IsTuple() + HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && - protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && + ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && tile_assignment_ == other.tile_assignment_ && tuple_elements_ == other.tuple_elements_; } @@ -207,6 +215,26 @@ class HloSharding { // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } + // Returns the flattened list of all the leaf shardings in a tuple shape, by + // pre-order walk (ShapeTree iterator order). + // REQUIRES: IsTuple(). + const std::vector& tuple_elements() const { + return tuple_elements_; + } + + // Return a new sharding that can apply to the given new shape. + // If this sharding is tile-maximal, the returned sharding will be the same as + // this sharding. If this sharding is not tile-maximal, the returned + // sharding's tile size will differ: + // - Non-sharded dimensions will be adapted to be the same as `new_shape`; + // tile_dimension(i) = new_shape.dimensions(i); + // - Sharded dimensions will be kept the same unless `transform` is supplied + // in which case tile_dimension(i) = transform(i, tile_dimension(i)); + // REQUIRES: !IsTuple(). + HloSharding TransformShardedTileShape( + const Shape& new_shape, + const std::function& transform = nullptr) const; + private: HloSharding() : replicated_(true), @@ -249,6 +277,8 @@ class HloSharding { std::vector tuple_elements_; }; +std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 0c7487b3ac77ff181d44dd55ebcf2608feaf02ea..69ea4233e45c2e59c8d1541a0517a007f4bbf42f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -269,5 +269,57 @@ TEST_F(HloShardingTest, Hash) { } } +TEST_F(HloShardingTest, TransformShardedTileShapeTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + HloSharding result = sharding.TransformShardedTileShape( + ShapeUtil::MakeShape(F32, {13, 15, 17, 19}), + [](int dim, int value) { return dim * 111; }); + HloSharding expected = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {13, 15, 222, 333}), + Array4D({{{{0, 1}, {2, 3}}}})); + EXPECT_EQ(result, expected); +} + +TEST_F(HloShardingTest, ToStringReplicatedTest) { + HloSharding sharding = HloSharding::Replicate(); + EXPECT_EQ(sharding.ToString(), "{replicated}"); +} + +TEST_F(HloShardingTest, ToStringAssignDeviceTest) { + HloSharding sharding = HloSharding::AssignDevice(7); + EXPECT_EQ(sharding.ToString(), "{maximal device=7}"); +} + +TEST_F(HloShardingTest, ToStringTiledTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(S32, {7, 11, 13}), + Array3D({{{2, 3}}, {{5, 7}}})); + EXPECT_EQ(sharding.ToString(), "{s32[7,11,13] devices=[2,1,2]2,3,5,7}"); +} + +TEST_F(HloShardingTest, ToStringTupleTest) { + HloSharding sharding = HloSharding::Tuple( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5}), + ShapeUtil::MakeShape(U32, {7, 25}), + ShapeUtil::MakeShape(S32, {9, 11})}), + {HloSharding::Replicate(), + HloSharding::Tile(ShapeUtil::MakeShape(U32, {7, 13}), + Array2D({{3, 5}})), + HloSharding::AssignDevice(3)}); + EXPECT_EQ(sharding.ToString(), + "{{replicated}, {u32[7,13] devices=[1,2]3,5}, {maximal device=3}}"); +} + +TEST_F(HloShardingTest, OstreamTest) { + HloSharding sharding = + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 5, 7, 11}), + Array4D({{{{0, 1}, {2, 3}}}})); + std::ostringstream oss; + oss << sharding; + EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e2b3bb9d71497c352b0b92add2d2f6b4b777bee8..63ec5964eb935239e86233c1ae94e2bcce3b0461 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -125,6 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { return CheckShape(outfeed, ShapeUtil::MakeNil()); } +Status ShapeVerifier::HandleHostCompute(HloInstruction*) { + return tensorflow::Status::OK(); +} + Status ShapeVerifier::HandleRng(HloInstruction*) { return tensorflow::Status::OK(); } @@ -170,17 +174,34 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); - ++operand_dimension) { - int64 output_dimension = broadcast->dimensions()[operand_dimension]; + for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { + int64 output_dimension = broadcast->dimensions()[i]; TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)) + operand_shape.dimensions(i)) << broadcast->ToString() << " operand shape " << operand_shape; } return tensorflow::Status::OK(); } +Status ShapeVerifier::HandleBroadcastDimOne(HloInstruction* broadcastDimOne) { + const Shape& operand_shape = broadcastDimOne->operand(0)->shape(); + int64 operand_rank = ShapeUtil::Rank(operand_shape); + const Shape& output_shape = broadcastDimOne->shape(); + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(broadcastDimOne, output_shape)); + TF_RET_CHECK(operand_rank == ShapeUtil::Rank(output_shape)); + for (int64 i = 0; i < operand_rank; ++i) { + int64 operand_dimension = operand_shape.dimensions(i); + int64 output_dimension = output_shape.dimensions(i); + TF_RET_CHECK(operand_dimension == 1 || + operand_dimension == output_dimension) + << "Dimension " << i << " of broadcastDimOne " + << broadcastDimOne->ToString() << " is " << operand_dimension + << ", expected 1 or " << output_dimension; + } + return tensorflow::Status::OK(); +} + Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { // Check for mixed precision. TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); @@ -420,6 +441,14 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace +Status ShapeVerifier::HandleGather(HloInstruction* gather) { + return CheckShape( + gather, + ShapeInference::InferGatherShape( + gather->operand(0)->shape(), gather->operand(1)->shape(), + gather->gather_dimension_numbers(), gather->gather_window_bounds())); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -750,11 +779,14 @@ StatusOr HloVerifier::Run(HloModule* module) { } else if (instruction->opcode() == HloOpcode::kBroadcast) { // If you see this failure then someone has confused the difference // between the HLO broadcast op, and the UserComputation broadcast - // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I // or ComputationLowerer::Visit() TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO has invalid number of dimensions."; + << "Broadcast HLO (" << instruction->ToShortString() + << ") has invalid number of dimensions: " + << instruction->dimensions().size() + << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { auto* while_cond = instruction->while_condition(); auto* while_body = instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7eccf834bbd3ac6af0d5762a7241758b416a3523..a4dff977ba268137d8ab94c576b4b511e911806f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -54,12 +54,14 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleReduce(HloInstruction* reduce) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleBroadcastDimOne(HloInstruction* broadcastDimOne) override; Status HandleReshape(HloInstruction* reshape) override; Status HandleTranspose(HloInstruction* transpose) override; Status HandleParameter(HloInstruction*) override; Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; + Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( @@ -79,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormInference( HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + Status HandleGather(HloInstruction* gather) override; Status FinishVisit(HloInstruction*) override { return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 90e1f0acdc4cdeda280dabaab2df66b181d0f407..3f4dbf897df7e1fd62f4229ed90c949c59da9d46 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -37,6 +37,7 @@ namespace xla { case HloOpcode::kBitcast: case HloOpcode::kBitcastConvert: case HloOpcode::kBroadcast: + case HloOpcode::kBroadcastDimOne: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kComplex: @@ -102,6 +103,8 @@ namespace xla { case HloOpcode::kExp: case HloOpcode::kFft: case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kMap: case HloOpcode::kParameter: @@ -140,7 +143,8 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { }); return std::count_if(hlo->operands().begin(), hlo->operands().end(), [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kBroadcastDimOne) { return false; } if (operand->opcode() == HloOpcode::kConstant && @@ -245,7 +249,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { auto reachability = computation->ComputeReachability(); auto cheap_to_duplicate = [this](HloInstruction* producer) { - if (producer->opcode() == HloOpcode::kBroadcast) { + if (producer->opcode() == HloOpcode::kBroadcast || + producer->opcode() == HloOpcode::kBroadcastDimOne) { return true; } if (producer->opcode() == HloOpcode::kConstant && @@ -300,7 +305,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { // Consider each operand of this instruction for fusion into this // instruction. We want to consider the operands in a particular order to - // avoid created duplicate instruction clones in the fusion instruction. + // avoid creating duplicate instruction clones in the fusion instruction. // For example, consider the following expression: // // A = ... @@ -375,7 +380,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting it's + // Operand is now dead. Remove from post order by setting its // location to nullptr. post_order[FindOrDie(post_order_index, operand)] = nullptr; post_order_index.erase(operand); diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 0819ab3b90b2360c6b0b2afaa89f322afe566eb3..45505484951abfcee93a62fec7a99e86cbb9150c 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -63,10 +63,7 @@ cc_library( name = "platform_id", srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], - deps = [ - "@nsync//:nsync_headers", - "//tensorflow/core:stream_executor_headers_lib", - ] + if_static( + deps = ["//tensorflow/core:stream_executor_headers_lib"] + if_static( ["@protobuf_archive//:protobuf"], ["@protobuf_archive//:protobuf_headers"], ), @@ -123,14 +120,3 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9171e859c6f84ceef9664aa1eb90a07c87dfab40..5b9bf5faf366d674ecadd59fa8a0af8d4976a962 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -96,7 +96,7 @@ InterpreterCompiler::CompileAheadOfTime( } se::Platform::Id InterpreterCompiler::PlatformId() const { - return sep::kInterpreterPlatformId; + return sep::kXlaInterpreterPlatformId; } HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() @@ -109,11 +109,11 @@ static std::unique_ptr CreateComputationPlacer() { } static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(sep::kInterpreterPlatformId, []() { + xla::Compiler::RegisterCompilerFactory(sep::kXlaInterpreterPlatformId, []() { return xla::MakeUnique(); }); - xla::ComputationPlacer::RegisterComputationPlacer(sep::kInterpreterPlatformId, - &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + sep::kXlaInterpreterPlatformId, &CreateComputationPlacer); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 0cb9b5d8107cd8bf468b07d5fe2a22930d9e8b8c..883063d0f075f5b0d79edc01bcd27a7c579272f4 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -93,7 +93,7 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( TF_ASSIGN_OR_RETURN(std::unique_ptr result, transfer_manager->AllocateShapedBuffer( result_literal->shape(), run_options->allocator(), - run_options->device_ordinal())); + executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( executor, *result_literal, *result)); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 68371910d76f42c0b6d4b1adad9d6a83bdb858e6..3caf9e7b82b21a84197ffe60267d6d953f9547a1 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -28,84 +28,85 @@ host::HostStream *AsExecutorStream(Stream *stream) { return dynamic_cast(stream->implementation()); } -InterpreterExecutor::InterpreterExecutor(const PluginConfig &plugin_config) +XlaInterpreterExecutor::XlaInterpreterExecutor( + const PluginConfig &plugin_config) : plugin_config_(plugin_config) {} -InterpreterExecutor::~InterpreterExecutor() {} +XlaInterpreterExecutor::~XlaInterpreterExecutor() {} -void *InterpreterExecutor::Allocate(uint64 size) { return new char[size]; } +void *XlaInterpreterExecutor::Allocate(uint64 size) { return new char[size]; } -void *InterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, - uint64 offset_bytes, - uint64 /*size_bytes*/) { +void *XlaInterpreterExecutor::AllocateSubBuffer(DeviceMemoryBase *parent, + uint64 offset_bytes, + uint64 /*size_bytes*/) { return parent + offset_bytes; } -void InterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { +void XlaInterpreterExecutor::Deallocate(DeviceMemoryBase *mem) { if (!mem->is_sub_buffer()) { delete[] static_cast(mem->opaque()); } } -bool InterpreterExecutor::Memcpy(Stream *stream, void *host_dst, - const DeviceMemoryBase &dev_src, uint64 size) { +bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, + const DeviceMemoryBase &dev_src, + uint64 size) { AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); return true; } -bool InterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, - const void *host_src, uint64 size) { +bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, + const void *host_src, uint64 size) { AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); return true; } -port::Status InterpreterExecutor::SynchronousMemcpy(DeviceMemoryBase *dev_dst, - const void *host_src, - uint64 size) { +port::Status XlaInterpreterExecutor::SynchronousMemcpy( + DeviceMemoryBase *dev_dst, const void *host_src, uint64 size) { memcpy(dev_dst->opaque(), host_src, size); return port::Status::OK(); } -port::Status InterpreterExecutor::SynchronousMemcpy( +port::Status XlaInterpreterExecutor::SynchronousMemcpy( void *host_dst, const DeviceMemoryBase &dev_src, uint64 size) { memcpy(host_dst, dev_src.opaque(), size); return port::Status::OK(); } -bool InterpreterExecutor::HostCallback(Stream *stream, - std::function callback) { +bool XlaInterpreterExecutor::HostCallback(Stream *stream, + std::function callback) { AsExecutorStream(stream)->EnqueueTask(callback); return true; } -bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, - Stream *other) { +bool XlaInterpreterExecutor::CreateStreamDependency(Stream *dependent, + Stream *other) { AsExecutorStream(dependent)->EnqueueTask( [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); AsExecutorStream(dependent)->BlockUntilDone(); return true; } -bool InterpreterExecutor::StartTimer(Stream *stream, Timer *timer) { +bool XlaInterpreterExecutor::StartTimer(Stream *stream, Timer *timer) { dynamic_cast(timer->implementation())->Start(stream); return true; } -bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { +bool XlaInterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { dynamic_cast(timer->implementation())->Stop(stream); return true; } -port::Status InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); return port::Status::OK(); } -DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { +DeviceDescription *XlaInterpreterExecutor::PopulateDeviceDescription() const { internal::DeviceDescriptionBuilder builder; builder.set_device_address_bits(64); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c5d07e906dafb033905c50c604069e80e1ce80cd..77426b0820d2d4e6a3a3216025837de7fa5e5c65 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Declares the InterpreterExecutor class, which is a CPU-only implementation of -// the StreamExecutor interface. For now, this is used for testing and to +// Declares the XlaInterpreterExecutor class, which is a CPU-only implementation +// of the StreamExecutor interface. For now, this is used for testing and to // examine the performance of host-based StreamExecutor code. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ @@ -50,10 +50,10 @@ namespace interpreter { using Args = tensorflow::gtl::ArraySlice; -class InterpreterExecutor : public internal::StreamExecutorInterface { +class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: - explicit InterpreterExecutor(const PluginConfig &plugin_config); - ~InterpreterExecutor() override; + explicit XlaInterpreterExecutor(const PluginConfig &plugin_config); + ~XlaInterpreterExecutor() override; port::Status Init(int device_ordinal, DeviceOptions device_options) override { return port::Status::OK(); diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index cf98ecd7749d61261bf072cdb1882c7603f39556..3cf8506d1c469d7745d26834a51b4ce0eebaa942 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -26,7 +26,7 @@ namespace sei = ::perftools::gputools::interpreter; namespace xla { InterpreterTransferManager::InterpreterTransferManager() - : GenericTransferManager(sei::kInterpreterPlatformId, + : GenericTransferManager(sei::kXlaInterpreterPlatformId, /*pointer_size=*/sizeof(void*)) {} } // namespace xla @@ -38,7 +38,7 @@ CreateInterpreterTransferManager() { static bool InitModule() { xla::TransferManager::RegisterTransferManager( - sei::kInterpreterPlatformId, &CreateInterpreterTransferManager); + sei::kXlaInterpreterPlatformId, &CreateInterpreterTransferManager); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index a60e7fc59f7c5f0b1b24e026b34e195ca0fe5ebb..015e00e1e8edc5c77066b6038f98621862af5440 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -35,17 +35,19 @@ namespace perftools { namespace gputools { namespace interpreter { -InterpreterPlatform::InterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform() : name_("Interpreter") {} -InterpreterPlatform::~InterpreterPlatform() {} +XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id InterpreterPlatform::id() const { return kInterpreterPlatformId; } +Platform::Id XlaInterpreterPlatform::id() const { + return kXlaInterpreterPlatformId; +} -int InterpreterPlatform::VisibleDeviceCount() const { return 1; } +int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } -const string& InterpreterPlatform::Name() const { return name_; } +const string& XlaInterpreterPlatform::Name() const { return name_; } -port::StatusOr InterpreterPlatform::ExecutorForDevice( +port::StatusOr XlaInterpreterPlatform::ExecutorForDevice( int ordinal) { StreamExecutorConfig config; config.ordinal = ordinal; @@ -55,7 +57,7 @@ port::StatusOr InterpreterPlatform::ExecutorForDevice( } port::StatusOr -InterpreterPlatform::ExecutorForDeviceWithPluginConfig( +XlaInterpreterPlatform::ExecutorForDeviceWithPluginConfig( int device_ordinal, const PluginConfig& plugin_config) { StreamExecutorConfig config; config.ordinal = device_ordinal; @@ -64,16 +66,17 @@ InterpreterPlatform::ExecutorForDeviceWithPluginConfig( return GetExecutor(config); } -port::StatusOr InterpreterPlatform::GetExecutor( +port::StatusOr XlaInterpreterPlatform::GetExecutor( const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( config, [&]() { return GetUncachedExecutor(config); }); } port::StatusOr> -InterpreterPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { +XlaInterpreterPlatform::GetUncachedExecutor( + const StreamExecutorConfig& config) { auto executor = port::MakeUnique( - this, port::MakeUnique(config.plugin_config)); + this, port::MakeUnique(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ @@ -86,17 +89,17 @@ InterpreterPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { return std::move(executor); } -void InterpreterPlatform::RegisterTraceListener( +void XlaInterpreterPlatform::RegisterTraceListener( std::unique_ptr listener) { LOG(FATAL) << "not yet implemented: register executor trace listener"; } -void InterpreterPlatform::UnregisterTraceListener(TraceListener* listener) { +void XlaInterpreterPlatform::UnregisterTraceListener(TraceListener* listener) { LOG(FATAL) << "not yet implemented: unregister executor trace listener"; } -static void InitializeInterpreterPlatform() { - std::unique_ptr platform(new sep::InterpreterPlatform); +static void InitializeXlaInterpreterPlatform() { + std::unique_ptr platform(new sep::XlaInterpreterPlatform); SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform))); } @@ -105,7 +108,7 @@ static void InitializeInterpreterPlatform() { } // namespace perftools REGISTER_MODULE_INITIALIZER(interpreter_platform, - sep::InitializeInterpreterPlatform()); + sep::InitializeXlaInterpreterPlatform()); DECLARE_MODULE_INITIALIZER(multi_platform_manager); diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index c66ddb907d1c5a8e99d3178a202a77a72a646ce5..2f71b29be4401a8374cdd0bad5830a632305fc26 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -27,10 +27,10 @@ namespace perftools { namespace gputools { namespace interpreter { -class InterpreterPlatform : public Platform { +class XlaInterpreterPlatform : public Platform { public: - InterpreterPlatform(); - ~InterpreterPlatform() override; + XlaInterpreterPlatform(); + ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -60,7 +60,7 @@ class InterpreterPlatform : public Platform { // Cache of created StreamExecutors. ExecutorCache executor_cache_; - SE_DISALLOW_COPY_AND_ASSIGN(InterpreterPlatform); + SE_DISALLOW_COPY_AND_ASSIGN(XlaInterpreterPlatform); }; } // namespace interpreter diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.cc b/tensorflow/compiler/xla/service/interpreter/platform_id.cc index 1a0373cf86e26b564e0e732e8de1a0a5d868bfa6..b7fb365b70db7235764435305085e36869cbb13a 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform_id.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform_id.cc @@ -18,7 +18,7 @@ namespace perftools { namespace gputools { namespace interpreter { -PLATFORM_DEFINE_ID(kInterpreterPlatformId); +PLATFORM_DEFINE_ID(kXlaInterpreterPlatformId); } // namespace interpreter } // namespace gputools diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.h b/tensorflow/compiler/xla/service/interpreter/platform_id.h index 905efef1690d3bd32461353fe32dd394eb6bca9e..292f958449b52ff2f522bd31f115079b4f7e0835 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform_id.h +++ b/tensorflow/compiler/xla/service/interpreter/platform_id.h @@ -22,7 +22,7 @@ namespace perftools { namespace gputools { namespace interpreter { -extern const Platform::Id kInterpreterPlatformId; +extern const Platform::Id kXlaInterpreterPlatformId; } // namespace interpreter } // namespace gputools diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index fce135ef61a7868386b869def1a79167c428d928..2494569db53f260b900b3d5d3d0d2da5b1fc5f73 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -53,6 +53,13 @@ limitations under the License. namespace xla { +// For now moving only one API here, but we should have a single top level +// anonymous namespace, instead of three or four spread all over this file. +namespace { + + +} // namespace + std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -115,17 +122,34 @@ LayoutConstraints::LayoutConstraints( } } +PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( + const HloInstruction* instruction) const { + auto it = buffer_sets_cache_.find(instruction); + if (it != buffer_sets_cache_.end()) { + return it->second.get(); + } + auto& buffer_set = + buffer_sets_cache_ + .emplace(instruction, MakeUnique()) + .first->second; + const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); + points_to_set.ForEachElement( + [&buffer_set](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + buffer_set->insert(buffers.begin(), buffers.end()); + }); + return buffer_set.get(); +} + bool LayoutConstraints::OperandBufferForwarded( const HloInstruction* instruction, int64 operand_no) const { // The operand is potentially forwarded if the intersection of points-to sets // of the operand and the instruction is non-empty. - auto output_buffers = - points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet(); - auto operand_buffers = - points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) - .CreateFlattenedSet(); - for (const LogicalBuffer* output_buffer : output_buffers) { - if (operand_buffers.count(output_buffer) > 0) { + PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); + PointsToSet::BufferSet* operand_buffers = + GetBufferSet(instruction->operand(operand_no)); + for (const LogicalBuffer* output_buffer : *output_buffers) { + if (operand_buffers->count(output_buffer) > 0) { return true; } } @@ -512,6 +536,36 @@ Status LayoutAssignment::AddMandatoryConstraints( body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( body_layout.result_shape(), instruction, 0)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + // The layout of the true and false computations must match, and must + // be the layout of the kConditional instruction. + TF_RET_CHECK(instruction->operand_count() == 3); + + HloComputation* true_computation = instruction->true_computation(); + HloComputation* false_computation = instruction->false_computation(); + const HloInstruction* true_operand = instruction->operand(1); + const HloInstruction* false_operand = instruction->operand(2); + + TF_RET_CHECK(true_computation->num_parameters() == 1); + TF_RET_CHECK(false_computation->num_parameters() == 1); + ComputationLayout& true_computation_layout = + FindOrDie(computation_layouts_, true_computation); + ComputationLayout& false_computation_layout = + FindOrDie(computation_layouts_, false_computation); + + DCHECK(ShapeUtil::Compatible(true_operand->shape(), + true_computation_layout.parameter_shape(0))); + DCHECK(ShapeUtil::Compatible( + false_operand->shape(), false_computation_layout.parameter_shape(0))); + + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + true_computation_layout.result_shape(), instruction)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + true_computation_layout.parameter_shape(0), instruction, 1, + /*mandatory=*/true)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + false_computation_layout.parameter_shape(0), instruction, 2, + /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { if (!CustomCallRequiresMajorFirstLayout(instruction)) { continue; @@ -598,6 +652,33 @@ Status CheckWhileLayout(HloInstruction* while_inst, return Status::OK(); } +Status CheckConditionalLayout( + HloInstruction* instruction, + const ComputationLayout& true_computation_layout, + const ComputationLayout& false_computation_layout) { + HloComputation* true_computation = instruction->true_computation(); + HloComputation* false_computation = instruction->false_computation(); + const HloInstruction* true_operand = instruction->operand(1); + const HloInstruction* false_operand = instruction->operand(2); + + TF_RET_CHECK(true_computation_layout.result_layout() == + false_computation_layout.result_layout()); + TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( + true_computation->root_instruction()->shape())); + TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( + false_computation->root_instruction()->shape())); + TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape( + true_operand->shape())); + TF_RET_CHECK( + false_computation_layout.parameter_layout(0).MatchesLayoutInShape( + false_operand->shape())); + return Status::OK(); +} + // Fusion parameters must match the layout of the fusion instructions operands, // and the root of the fusion expression must match the layout of the fusion // instruction. @@ -642,6 +723,99 @@ Status CheckConstantLayout(HloInstruction* constant) { } // namespace +StatusOr LayoutAssignment::CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction) { + TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); + DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) + << ShapeUtil::HumanString(shape_with_layout) << " " + << ShapeUtil::HumanString(instruction->shape()) + << " instruction: " << instruction->ToString(); + + if (ShapeUtil::IsTuple(instruction->shape())) { + // Deep-copy tuples. + std::vector element_copies; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + ++i) { + HloInstruction* gte = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + SetupCopiedInstruction(*instruction, gte, {i}); + // Recurse to copy each elements. + TF_ASSIGN_OR_RETURN( + HloInstruction * element_copy, + CreateCopyWithNewLayout( + ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); + element_copies.push_back(element_copy); + } + // Gather element copies into a tuple with a new Tuple instruction. + HloInstruction* tuple_copy = instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(element_copies)); + SetupCopiedInstruction(*instruction, tuple_copy, {}); + LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, tuple_copy->mutable_shape())); + return tuple_copy; + } else if (ShapeUtil::IsArray(instruction->shape())) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + SetupCopiedInstruction(*instruction, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, copy->mutable_shape())); + + return copy; + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); + } +} + +// Creates a copy of the given operand if the operand's layout does not match +// the given layout. This copy replaces the use in the given instruction. Tuple +// operands will be deep-copied. +Status LayoutAssignment::CopyOperandIfLayoutsDiffer( + const ShapeLayout& operand_layout, HloInstruction* instruction, + int64 operand_no) { + HloInstruction* operand = instruction->mutable_operand(operand_no); + TF_RET_CHECK(operand_layout.LayoutIsSet()); + TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); + + if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + // Operand layout already matches our constraint. Nothing to do. + return Status::OK(); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, + CreateCopyWithNewLayout(operand_layout.shape(), operand)); + + return instruction->ReplaceOperandWith(operand_no, operand_copy); +} + +void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, + HloInstruction* copy, + const ShapeIndex& index) { + if (instruction.has_sharding()) { + // If the index is empty, we want to copy the whole sharding, in case the + // sharding is a tuple sharding. + HloSharding sharding = + !index.empty() && instruction.sharding().IsTuple() + ? instruction.sharding().GetSubSharding(instruction.shape(), index) + : instruction.sharding(); + // We propagate the sharding to the copied instruction only if it is a + // special sharding, like tiled ones, or special devices like the + // HostCompute module. + // Otherwise it is preferable to leave the new instruction without device, + // and let the automatic device placer to choose the best location. + if (!sharding.HasUniqueDevice() || + HloSharding::IsReservedDevice(sharding.UniqueDevice().ValueOrDie())) { + copy->set_sharding(sharding); + } + } + copy->set_metadata(instruction.metadata()); +} + Status LayoutAssignment::CheckLayouts(HloModule* module) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -710,6 +884,13 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->while_condition()), FindOrDie(computation_layouts_, instruction->while_body()))); break; + case HloOpcode::kConditional: + TF_RETURN_IF_ERROR(CheckConditionalLayout( + instruction, + FindOrDie(computation_layouts_, instruction->true_computation()), + FindOrDie(computation_layouts_, + instruction->false_computation()))); + break; default: break; } @@ -1165,77 +1346,6 @@ StatusOr InferArrayLayout( return *first_buffer_layout; } -// Creates and returns a copy of the given instruction with a different -// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple -// instruction producing the copy is returned. -StatusOr CreateCopyWithNewLayout( - const Shape& shape_with_layout, HloInstruction* instruction) { - TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); - DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) - << ShapeUtil::HumanString(shape_with_layout) << " " - << ShapeUtil::HumanString(instruction->shape()) - << " instruction: " << instruction->ToString(); - - if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. - std::vector element_copies; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); - ++i) { - HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); - } - // Gather element copies into a tuple with a new Tuple instruction. - HloInstruction* tuple_copy = instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(element_copies)); - LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, tuple_copy->mutable_shape())); - return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { - HloInstruction* copy = - instruction->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - LayoutUtil::ClearLayout(copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, copy->mutable_shape())); - - return copy; - } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); - } -} - -// Creates a copy of the given operand if the operand's layout does not match -// the given layout. This copy replaces the use in the given instruction. Tuple -// operands will be deep-copied. -Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no) { - HloInstruction* operand = instruction->mutable_operand(operand_no); - TF_RET_CHECK(operand_layout.LayoutIsSet()); - TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); - - if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { - // Operand layout already matches our constraint. Nothing to do. - return Status::OK(); - } - - TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, - CreateCopyWithNewLayout(operand_layout.shape(), operand)); - - return instruction->ReplaceOperandWith(operand_no, operand_copy); -} - // For fusion instructions, set the layout of each fused parameter instruction // to match the layout of its corresponding fusion instruction operand. Also, // set the layout of the fused root to match the layout of the fusion @@ -1474,6 +1584,13 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // infeeds. Clearing the layouts here avoids hiding potential bugs in the // layout assignment pass that may accidently use the existing layout. for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } if (instruction->opcode() != HloOpcode::kInfeed) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 29018584487cabfd740d7914625c2a50f552d6ff..ae4986d6ad9bc3de100eab9cc38b709bb56c7813 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -199,6 +200,11 @@ class LayoutConstraints { string ToString() const; private: + // Find a bufferset in the bufferset cache. This is useful since we can + // currently create the flattened buffer set for the same instruction many + // times, which is often slow. + PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; + // The set of BufferLayoutConstraints applied to the computation. std::unordered_map buffer_constraints_; @@ -221,6 +227,10 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; + mutable tensorflow::gtl::FlatMap> + buffer_sets_cache_; + HloComputation* computation_; }; @@ -393,14 +403,37 @@ class LayoutAssignment : public HloPassInterface { Status CheckLayouts(HloModule* module); ComputationLayout* entry_computation_layout_; - ChannelLayoutConstraints* channel_layout_constraints_; protected: + // Sets up the copy instruction according to the characteristic (sharding, + // metadata, ...) of the reference instruction. The index argument is used + // when the instruction is a tuple, and in such case the index represents + // the location from where the copy instruction was created from. + // If the index is empty, the whole sharding will be propagated, even in case + // the intruction has a tuple sharding. + static void SetupCopiedInstruction(const HloInstruction& instruction, + HloInstruction* copy, + const ShapeIndex& index); + + // Creates and returns a copy of the given instruction with a different + // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple + // instruction producing the copy is returned. + static StatusOr CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction); + + // Creates a copy of the given operand if the operand's layout does not match + // the given layout. This copy replaces the use in the given instruction. + // Tuple operands will be deep-copied. + static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + ChannelLayoutConstraints* channel_layout_constraints_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index e269a13459f1146f1d2952870399827d9e705e38..4b1c9bad41de8030cf14bc6d1c0db21b9c56c3bf 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -590,6 +590,85 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { transpose->shape(), {2, 3, 0, 1})); } +// TransposeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = builder.AddInstruction( + HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1})); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(), + hlo->shape(), hlo->dimensions()), + "LayoutUtil::HasLayout"); +} + +// ReshapeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = + builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param)); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH( + ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()), + "LayoutUtil::HasLayout"); +} + +// Check that the computation below doesn't crash the compiler. +// +// Within a fusion computation, only the parameters and result get assigned a +// layout. When we run the algebraic simplifier on this computation post layout +// assignment, it should not call TransposeIsBitcast on the `transpose` node +// inside the fusion computation as TransposeIsBitcast checks both input_shape +// and output_shape have layouts. +TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { + const char* module_str = R"( + HloModule test_module + + fused_computation { + param_1 = f32[2,2,2]{2,1,0} parameter(1) + transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1} + reduce_1 = f32[] parameter(0) + broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={} + ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1) + } + + ENTRY entry_computation { + fusion.1 = f32[2,2,2]{2,1,0} parameter(1) + reduce.1 = f32[] parameter(0) + fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation + ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2) + } + )"; + + auto module = tools::Parse(module_str).ValueOrDie(); + + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + EXPECT_EQ( + ::tensorflow::Status::OK(), + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); +} + // A GTE inside of a fusion node inherits the layout of its operand (which // should, if we keep following operands, eventually be a parameter). TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { @@ -629,33 +708,113 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - HloComputation* fused_computation = *std::find_if( - module->computations().begin(), module->computations().end(), - [](const HloComputation* c) { return c->name() == "fused_computation"; }); - - auto fused_instr = [&](const string& name) { - auto it = std::find_if( - fused_computation->instructions().begin(), - fused_computation->instructions().end(), - [&](const HloInstruction* i) { return i->name() == name; }); - CHECK(it != fused_computation->instructions().end()); - return *it; + auto layout_of = [&](tensorflow::StringPiece name) { + return FindInstruction(module.get(), name) + ->shape() + .layout() + .minor_to_major(); }; - EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(), - ElementsAre(0, 1, 2)); - EXPECT_THAT( - fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(), - ElementsAre(1, 2, 0)); - EXPECT_THAT( - fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(), - ElementsAre(2, 0, 1)); - EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(), + EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(module.get(), "gte1") + ->shape() + .tuple_shapes(0) + .layout() + .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(), + EXPECT_THAT(FindInstruction(module.get(), "gte1") + ->shape() + .tuple_shapes(1) + .layout() + .minor_to_major(), ElementsAre(2, 0, 1)); - EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(), - ElementsAre(2, 1, 0)); +} + +TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { + auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewModule(); + Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); + Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); + Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + auto pred = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {}), "param2")); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + + auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch"); + { + auto param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tshape, "param")); + auto gte0 = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, param, 0)); + auto gte1 = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, param, 1)); + auto add = true_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1)); + true_builder.AddInstruction(HloInstruction::CreateTuple({add})); + } + HloComputation* true_computation = + module->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); + { + Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1}); + false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tshape, "param")); + // Using infeed as layout assignment does not mess up with it. + auto infeed = + false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + } + HloComputation* false_computation = + module->AddEmbeddedComputation(false_builder.Build()); + builder.AddInstruction(HloInstruction::CreateConditional( + result_tshape, pred, tuple, true_computation, tuple, false_computation)); + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + AssignLayouts(module.get(), &computation_layout); + + const HloInstruction* true_root = true_computation->root_instruction(); + const HloInstruction* false_root = false_computation->root_instruction(); + EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple); + EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple); + + const HloInstruction* true_result = true_root->operand(0); + const HloInstruction* false_result = false_root->operand(0); + EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(), + false_result->shape().layout())); + EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); +} + +TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant0->shape(), HloOpcode::kBitcast, constant0)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + LayoutAssignment layout_assignment(&computation_layout); + Status error_status = layout_assignment.Run(module.get()).status(); + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT( + error_status.error_message(), + ::testing::HasSubstr( + "Unexpected bitcast operation seen during layout assignment")); } } // namespace diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 2c2a02f6375343d67dfb155bbb03729ff6e490d2..f8b309488eeb5391b1cad5db760934ec1f7e3521 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase { CHECK_NOTNULL(module_.get()); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = - HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie(); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index f98fc0400a7d827a29dcddc5eecf9a4a01e76590..911b243fe28a5baf8a4b8ed752b892265f5388ac 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -14,12 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/core/platform/denormal.h" + +#ifdef __FAST_MATH__ +#error "Don't build XLA with -ffast-math" +#endif namespace xla { StatusOr>> LLVMCompiler::Compile( std::vector> modules, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { + // Tensorflow tries to enable the following behaviors in all its threads: + // + // - Denormals are zero (DAZ): roughly, operations treat denormal floats as + // zero. + // - Flush denormals to zero (FTZ): roughly, operations produce zero instead + // of denormal floats. + // + // In theory enabling these shouldn't matter since the compiler should ideally + // not leak its environment into generated code, but we turn off DAZ and FTZ + // to get some defense-in-depth. + tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals; + std::vector> result; for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 37261ed1e665ebed9685751161a412ad114a9e96..f1e7fc29532ce7e6841010a5258f4000a7c70383 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -169,17 +169,3 @@ cc_library( "@llvm//:core", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 6384c7f46f5ebbedaeda232b40095611a5d738a4..3312a888443233139841ce7a5e3173f907605e1d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,18 +29,13 @@ limitations under the License. namespace xla { namespace llvm_ir { -IrArray::Index::Index(llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) - : multidim_(ShapeUtil::Rank(shape)), - linear_(linear), - layout_(shape.layout()), - dims_(shape.dimensions().begin(), shape.dimensions().end()) { - CHECK(LayoutUtil::HasLayout(shape)) - << "Shape " << ShapeUtil::HumanStringWithLayout(shape) - << " should have a layout."; +static void Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) { int64 divisor = 1; - for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) { - int64 dimension = layout_.minor_to_major(i); + const Layout& layout = shape.layout(); + for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { + int64 dimension = layout.minor_to_major(i); int64 size_of_current_dimension = shape.dimensions(dimension); // If i is not the last dimension, compute @@ -54,16 +49,28 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); - if (i < layout_.minor_to_major_size() - 1) { - multidim_[dimension] = ir_builder->CreateURem( + if (i < layout.minor_to_major_size() - 1) { + (*multidim)[dimension] = ir_builder->CreateURem( quot, ir_builder->getInt64(size_of_current_dimension)); } else { - multidim_[dimension] = quot; + (*multidim)[dimension] = quot; } divisor *= size_of_current_dimension; } } +IrArray::Index::Index(llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) + : multidim_(ShapeUtil::Rank(shape)), + linear_(linear), + layout_(shape.layout()), + dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK(LayoutUtil::HasLayout(shape)) + << "Shape " << ShapeUtil::HumanStringWithLayout(shape) + << " should have a layout."; + Delinearize(&multidim_, linear, shape, ir_builder); +} + IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), @@ -83,7 +90,6 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, dims_(shape.dimensions().begin(), shape.dimensions().end()) { CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); - linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder); } IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) @@ -106,16 +112,13 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) } } -// Returns whether given linear index valid on given shape. +// Returns whether the given linear index is valid on the given shape. bool IrArray::Index::LinearValidOnShape(const Shape& a) const { - auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_); + auto b = ShapeUtil::MakeShape(a.element_type(), dims_); *b.mutable_layout() = layout_; return linear_ != nullptr && - ContainersEqual( - ShapeUtil::StripDegenerateDimensions(a).dimensions(), - ShapeUtil::StripDegenerateDimensions(b).dimensions()) && - LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(), - ShapeUtil::StripDegenerateDimensions(b).layout()); + ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) && + ShapeUtil::ReshapeIsBitcast(a, b); } IrArray::Index IrArray::Index::SourceIndexOfReshape( @@ -160,7 +163,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } } - if (linear() != nullptr && + if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape) && ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } @@ -195,13 +199,111 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose( llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); - if (linear() != nullptr && + + if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && + LayoutUtil::HasLayout(shape) && ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { return Index(operand_multidim_index, linear(), operand_shape); } + return Index(operand_multidim_index); } +IrArray::Index IrArray::Index::SourceIndexOfBitcast( + const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const { + CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape)); + // In case the bitcast is just a reshape, we can use SourceIndexOfReshape() + // instead. This will reuse linear() if possible, so we don't have to build a + // new 'linear_index'. + if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) { + return SourceIndexOfReshape(shape, operand_shape, builder); + } + + // First linearize the index coming from the output of the bitcast. We want + // the physical index of the element in the buffer. This is like Linearize, + // but takes the layout into account. + int64 scale = 1; + llvm::Value* linear_index = builder->getInt64(0); + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { + linear_index = builder->CreateAdd( + linear_index, + builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + /*HasNUW=*/true, /*HasNSW=*/true), + "", /*HasNUW=*/true, /*HasNSW=*/true); + scale *= shape.dimensions(dimension); + } + + // Now delinearize it for the input of the bitcast. + std::vector multi_index(operand_shape.dimensions_size()); + Delinearize(&multi_index, linear_index, operand_shape, builder); + + return Index(multi_index, linear_index, operand_shape); +} + +IrArray::Index IrArray::Index::SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const { + int64 rank = ShapeUtil::Rank(operand_shape); + std::vector source_index(rank); + for (int64 i = 0; i < rank; ++i) { + source_index[i] = multidim_[dimension_mapping[i]]; + } + if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || + !LayoutUtil::HasLayout(shape)) { + return Index(source_index); + } + // High-level idea: we can reuse the linear index if the broadcasted + // dimensions are contiguous, and this part of the operation is a bitcast. + // The other dimensions can be masked out with a div and a mod operation. + std::vector logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(shape.layout()); + int64 output_rank = ShapeUtil::Rank(shape); + // The minimum physical dimension that is broadcasted. + int64 min_broadcasted_dimension = output_rank; + // The maximum physical dimension that is broadcasted. + int64 max_broadcasted_dimension = -1; + for (int64 i = 0; i < rank; ++i) { + int64 physical_dim = logical_to_physical[dimension_mapping[i]]; + min_broadcasted_dimension = + std::min(min_broadcasted_dimension, physical_dim); + max_broadcasted_dimension = + std::max(max_broadcasted_dimension, physical_dim); + } + bool contiguous_broadcast_dimensions = + max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; + if (!contiguous_broadcast_dimensions) { + return Index(source_index); + } + // Check if the mapped dimensions are a bitcast. + std::vector operand_logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(operand_shape.layout()); + for (int64 i = 0; i < rank; ++i) { + if (operand_logical_to_physical[i] != + logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { + return Index(source_index); + } + } + llvm::Value* linear = linear_; + int64 divisor = 1; + for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) { + divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + if (divisor > 1) { + linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + } + if (min_broadcasted_dimension > 0) { + int64 mod = 1; + for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension; + ++i) { + mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); + } + linear = builder->CreateURem(linear, builder->getInt64(mod)); + } + return Index(source_index, linear, operand_shape); +} + llvm::Value* IrArray::Index::Linearize( tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const { diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 387d4629125cbb791840e943013188d14159908a..06cfb2a36c56c5fdece7140e469379f8394111fa 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -76,8 +76,7 @@ class IrArray { llvm::IRBuilder<>* ir_builder); // Constructs an index from the given multi-dimensional index and the shape - // that it indexes into. Also, computes the linear index according to - // "shape". + // that it indexes into. // // Precondition: "shape" has a layout. Index(tensorflow::gtl::ArraySlice multidim, @@ -134,6 +133,18 @@ class IrArray { tensorflow::gtl::ArraySlice dimension_mapping, llvm::IRBuilder<>* builder) const; + // Given that "this" is the target index of a bitcast from `operand_shape` + // to `shape`, returns the source index. + Index SourceIndexOfBitcast(const Shape& shape, const Shape& operand_shape, + llvm::IRBuilder<>* builder) const; + + // Given that "this" is the target index of a broadcast from `operand_shape` + // to `shape` with the given dimension mapping, returns the source index. + Index SourceIndexOfBroadcast( + const Shape& shape, const Shape& operand_shape, + tensorflow::gtl::ArraySlice dimension_mapping, + llvm::IRBuilder<>* builder) const; + // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 22141e7e00756483957f9cd4bc065a64556e854c..ec04239b4f9112134ba876fdfbb3905a3baf1f72 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -65,7 +66,7 @@ llvm::StringRef AsStringRef(tensorflow::StringPiece str) { std::unique_ptr DropConstantInitializers( const llvm::Module& module) { - std::unique_ptr cloned_module = CloneModule(&module); + std::unique_ptr cloned_module = CloneModule(module); for (llvm::GlobalVariable& global_var : cloned_module->globals()) { global_var.setInitializer(nullptr); global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage); @@ -106,8 +107,10 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = ir_builder->CreateFCmpUGE(lhs_value, rhs_value); return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); } else { - return EmitCallToIntrinsic(llvm::Intrinsic::maxnum, {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder); + auto cmp_ge = ir_builder->CreateFCmpOGE(lhs_value, rhs_value); + auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = ir_builder->CreateOr(cmp_ge, lhs_is_nan); + return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); } } @@ -117,8 +120,10 @@ llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = ir_builder->CreateFCmpULE(lhs_value, rhs_value); return ir_builder->CreateSelect(cmp, lhs_value, rhs_value); } else { - return EmitCallToIntrinsic(llvm::Intrinsic::minnum, {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder); + auto cmp_le = ir_builder->CreateFCmpOLE(lhs_value, rhs_value); + auto lhs_is_nan = ir_builder->CreateFCmpUNE(lhs_value, lhs_value); + auto sel_lhs = ir_builder->CreateOr(cmp_le, lhs_is_nan); + return ir_builder->CreateSelect(sel_lhs, lhs_value, rhs_value); } } @@ -758,7 +763,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { + if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index b6b918ec78a27b90325f72eea14b97f9aee43c54..3978acc132f34b8b195d3772ccf71d0d467984db 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -88,12 +88,12 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } } -IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( +std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name) { if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; - return IrArray::Index(); + return {IrArray::Index()}; } // Create loop nest with one for-loop for each dimension of the target shape. @@ -121,12 +121,14 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); CHECK_NOTNULL(exit_bb_); - return array_index; + return {array_index}; } tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { - IrArray::Index array_index = EmitIndexAndSetExitBasicBlock(loop_name); - TF_RETURN_IF_ERROR(body_emitter_(array_index)); + for (const IrArray::Index& array_index : + EmitIndexAndSetExitBasicBlock(loop_name)) { + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + } // Set the insertion point of ir_builder_ to the loop exit, so that // code emitted for later instructions will be correctly placed. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 0fc528439a0d5bf8382dfcf2d8b3051f8900bf1d..9ff497aecd0bc964c929205c7fd410cca87d9b77 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -63,11 +63,12 @@ class LoopEmitter { // Emits a loop nest (with a yet-to-be-filled loop body) that iterates through // every element in the given shape. Returns the multi-dimensional index that - // specifies the element. - IrArray::Index EmitIndexAndSetExitBasicBlock() { + // specifies the element, will return multiple indices if the loop is + // unrolled. + std::vector EmitIndexAndSetExitBasicBlock() { return EmitIndexAndSetExitBasicBlock(/*loop_name=*/""); } - virtual IrArray::Index EmitIndexAndSetExitBasicBlock( + virtual std::vector EmitIndexAndSetExitBasicBlock( tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 07f989d4faea199e812e54d2ae74d3ff9e7fa19a..499f280211aacd00e79b3ca0ddb3413f933b02da 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -69,6 +69,68 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} +namespace { + +// Retrieves the parameter metadata for the given computation and parameter +// number. +// +// If the parameter number is invalid for this computation, nullopt is +// returned. When the return value has_value(), nullptr will never be +// the held value. +tensorflow::gtl::optional ParameterMetadata( + const XlaComputation& computation, int parameter_number) { + for (const HloComputationProto& comp : computation.proto().computations()) { + if (comp.id() == computation.proto().entry_computation_id()) { + for (const HloInstructionProto& instr : comp.instructions()) { + if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && + instr.parameter_number() == parameter_number) { + if (!instr.has_metadata()) { + return tensorflow::gtl::nullopt; + } + return &instr.metadata(); + } + } + } + } + return tensorflow::gtl::nullopt; +} + +ExecutionOptions CreateExecutionOptions( + const ExecutableBuildOptions& build_options, + const ProgramShape* program_shape) { + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (build_options.hlo_profile().has_value()) { + execution_options.mutable_debug_options()->set_xla_hlo_profile( + *build_options.hlo_profile()); + } + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.dump_optimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_optimized_hlo_proto_to( + build_options.dump_optimized_hlo_proto_to().value()); + } + if (build_options.dump_per_pass_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_per_pass_hlo_proto_to( + build_options.dump_per_pass_hlo_proto_to().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); + } else { + *execution_options.mutable_shape_with_output_layout() = + program_shape->result(); + LayoutUtil::SetToDefaultLayout( + execution_options.mutable_shape_with_output_layout()); + } + return execution_options; +} + +} // namespace + StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -118,30 +180,78 @@ StatusOr> LocalService::CompileExecutable( *build_options.result_layout(), program_shape->result())); } - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (build_options.generate_hlo_graph().has_value()) { - execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( - build_options.generate_hlo_graph().value()); + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, program_shape.get()); + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(*program_shape, argument_layouts, + &execution_options, user_computation)); + + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(build_options.device_ordinal())); + + return BuildExecutable(versioned_handle, std::move(module_config), + execute_backend_.get(), executor, + build_options.device_allocator()); +} + +StatusOr> LocalService::CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& build_options) { + const HloModuleProto& proto = computation.proto(); + TF_RET_CHECK(proto.has_program_shape()); + const ProgramShape& program_shape = proto.program_shape(); + + // Validate incoming layouts. + if (argument_layouts.size() != program_shape.parameters_size()) { + return InvalidArgument( + "Invalid number of arguments for computation: expected %d, got %zu.", + program_shape.parameters_size(), argument_layouts.size()); + } + + for (int i = 0; i < argument_layouts.size(); ++i) { + const Shape& argument_shape = *argument_layouts[i]; + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { + tensorflow::gtl::optional metadata = + ParameterMetadata(computation, /*parameter_number=*/i); + auto metadata_string = [&metadata]() -> string { + if (!metadata.has_value()) { + return ""; + } + CHECK(metadata.value() != nullptr); + const OpMetadata& m = *metadata.value(); + if (!m.source_file().empty()) { + return tensorflow::strings::Printf( + " (%s:%d)", m.source_file().c_str(), m.source_line()); + } + return ""; + }; + return InvalidArgument( + "Invalid argument shape for argument %d%s, expected %s, got %s.", i, + metadata_string().c_str(), + ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(argument_shape).c_str()); + } } if (build_options.result_layout() != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *build_options.result_layout(); - } else { - *execution_options.mutable_shape_with_output_layout() = - program_shape->result(); - LayoutUtil::SetToDefaultLayout( - execution_options.mutable_shape_with_output_layout()); + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape.result())); } + + ExecutionOptions execution_options = + CreateExecutionOptions(build_options, &program_shape); + TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - *user_computation)); + CreateModuleConfig(program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); - return BuildExecutable(versioned_handle, std::move(module_config), + return BuildExecutable(proto, std::move(module_config), execute_backend_.get(), executor, build_options.device_allocator()); } diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 15e120685e1be9190d49fdaf5ed6706bdf991a6c..06567cabd6eb28aae53881613cd6beb78e25e222 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -50,6 +51,18 @@ class LocalService : public Service { const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options); + // Builds an Executable with the given XlaComputation, argument layouts and + // options. If result_layout is non-null, then the executable is compiled to + // produce a result of the given layout. If device_allocator is non-null, + // then the compiler may use it to allocate temp space on the device. The + // compiler is responsible for freeing any memory it allocates this way. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> CompileExecutable( + const XlaComputation& computation, + const tensorflow::gtl::ArraySlice argument_layouts, + const ExecutableBuildOptions& build_options); + // Returns the device ordinal that corresponds to the given replica number. // // This returns an error if there is not a one-to-one correspondence of diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h new file mode 100644 index 0000000000000000000000000000000000000000..5d4963807721eb177400131fa16a69f32fb431ab --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -0,0 +1,1014 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// A pattern matcher for HloInstructions, Shapes, and Layouts. +// +// The Match function's first argument must be HloInstruction*, Shape*, or +// Layout*. The second argument is a pattern that will be matched against the +// first argument, as described below. +// +// Patterns are constructed using the match::Op, match::Shape, or match::Layout +// functions. By default, the returned patterns will match any HloInstruction, +// Shape, or Layout, respectively. However the match can be made more specific +// by using the pattern's modifier methods, for example: +// +// match::Op().WithOpcode(HloOpcode::kAdd).WithOperand( +// 0, match::Op().WithOpcode(HloOpcode::kConstant)) +// +// This pattern will match Add instructions whose first operand is a constant. +// +// Each pattern type has the following modifiers: +// +// Op(): +// - WithName: match operations with the given name +// - WithOpcode: match operations with the given opcode +// - WithShape: match operations whose shape matches the given pattern +// - WithOperand: match operations whose operand matches the given pattern +// +// Shape(): +// - EqualTo: matches shapes that are equal to the argument +// - CompatibleTo: matches shapes that are compatible to the argument +// - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes +// - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format +// - WithLayout: match shapes whose layout matches the given pattern +// - WithLayoutEqualTo: matches shapes whose layouts equal the argument +// - WithSubshape: matches tuple shapes whose subshape matches the given +// pattern +// - WithSubshapeEqualTo: matches shapes with a subshape equal the argument +// - WithElementType: matches array/scalar shapes with the given element +// type +// - WithRank: matches array/scalar types with the given rank +// +// Layout(): +// - EqualTo: matches layouts that are equal to the argument +// - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse +// format +// +// Op(), Shape(), and Layout() may be passed an argument of type +// HloInstruction**, Shape**, or Layout**, respectively, or const versions of +// these pointers. If the pattern is matched, the address of the matched value +// will be "captured" and stored at this location. +// +// For example: +// HloInstruction* foo = ...; +// HloInstruction* matched_operand; +// CHECK(Match(foo, +// match::Op().WithOperand(0, match::Op(&matched_operand)))); +// +// Helpers are provided for common nullary, unary, binary, and ternary +// instructions. These helpers can be called with no arguments, in which case +// they will match any instruction matching the opcode. They may also be called +// with matches for the operands and with an optional capture. (The capture must +// be the first argument.) Some examples of these helpers and their equivalents +// are provided below. +// +// Example nullary instruction: +// Recv() == Op().WithOpcode(HloOpcode::kRecv) +// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv) +// +// Example unary instruction: +// Abs() == Op().WithOpcode(HloOpcode::kAbs) +// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&a))) +// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&b)) +// +// Example binary instruction: +// Add() == Op().WithOpcode(HloOpcode::kAdd) +// Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) +// .WithOperand(0, Op(&a)) +// .WithOperand(1, Op(&b)) +// Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) +// .WithOperand(0, Op(&b)) +// .WithOperand(1, Op(&c)) +// +// Example ternary instruction: +// Clamp() == Op().WithOpcode(HloOpcode::kClamp) +// Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) +// .WithOperand(0, Op(&a)) +// .WithOperand(1, Op(&b)) +// .WithOperand(2, Op(&c)) +// Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) +// .WithOperand(0, Op(&b)) +// .WithOperand(1, Op(&c)) +// .WithOperand(2, Op(&d)) +// +template +bool Match(Value* value, const Pattern& pattern) { + return pattern.Match(value); +} + +namespace match { + +namespace detail { + +template +class LayoutPattern; + +// The base LayoutPattern implementation. Matches only if the layout is not +// nullptr. +class LayoutPatternBaseImpl { + public: + bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } +}; + +// A LayoutPattern implementation that matches only if the layout equals a +// Layout proto. +template +class LayoutPatternEqualImpl { + public: + explicit constexpr LayoutPatternEqualImpl(const Previous& previous, + const ::xla::Layout* layout) + : previous_(previous), layout_(layout) {} + + bool Match(const ::xla::Layout* layout) const { + return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + } + + private: + Previous previous_; + const ::xla::Layout* layout_; +}; + +// A LayoutPattern implementation that matches only if the layout has a given +// format. +template +class LayoutPatternFormatImpl { + public: + explicit constexpr LayoutPatternFormatImpl(const Previous& previous, + Format format) + : previous_(previous), format_(format) {} + + bool Match(const ::xla::Layout* layout) const { + return previous_.Match(layout) && layout->format() == format_; + } + + private: + Previous previous_; + Format format_; +}; + +// A pattern that matches Layouts. +template +class LayoutPattern { + public: + explicit constexpr LayoutPattern(const Impl& impl, + LayoutType** matched_layout) + : impl_(impl), matched_layout_(matched_layout) {} + + // Returns true and captures the layout iff it matches the pattern. + bool Match(const ::xla::Layout* layout) const { + if (impl_.Match(layout)) { + if (matched_layout_) { + *matched_layout_ = layout; + } + return true; + } + return false; + } + + // Returns true and captures the layout iff it matches the pattern. + bool Match(::xla::Layout* layout) const { + if (impl_.Match(layout)) { + if (matched_layout_) { + *matched_layout_ = layout; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the layout equals the given proto. + // The layout must outlive the returned pattern. + constexpr LayoutPattern> EqualTo( + const Layout* layout) const { + return LayoutPattern>( + LayoutPatternEqualImpl(impl_, layout), matched_layout_); + } + + // Modifies the pattern to match only if the layout has a dense format. + constexpr LayoutPattern> + WithDenseFormat() const { + return LayoutPattern>( + LayoutPatternFormatImpl(impl_, DENSE), matched_layout_); + } + + // Modifies the pattern to match only if the layout has a sparse format. + constexpr LayoutPattern> + WithSparseFormat() const { + return LayoutPattern>( + LayoutPatternFormatImpl(impl_, SPARSE), matched_layout_); + } + + private: + Impl impl_; + LayoutType** matched_layout_; +}; + +} // namespace detail + +// Creates a layout pattern that will capture the matched layout in the +// argument. +inline constexpr detail::LayoutPattern +Layout(const ::xla::Layout** matched_layout = nullptr) { + return detail::LayoutPattern( + detail::LayoutPatternBaseImpl(), matched_layout); +} + +// Creates a layout pattern that will capture the matched layout in the +// argument. +inline constexpr detail::LayoutPattern<::xla::Layout, + detail::LayoutPatternBaseImpl> +Layout(::xla::Layout** matched_layout) { + return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( + detail::LayoutPatternBaseImpl(), matched_layout); +} + +namespace detail { + +template +class ShapePattern; + +// The base ShapePattern implementation. Matches only if the shape is not +// nullptr. +class ShapePatternBaseImpl { + public: + bool Match(const ::xla::Shape* shape) const { return shape != nullptr; } +}; + +// A ShapePattern implementation that matches only if the shape equals a Shape +// proto. +template +class ShapePatternEqualImpl { + public: + explicit constexpr ShapePatternEqualImpl(const Previous& previous, + const ::xla::Shape* shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + } + + private: + Previous previous_; + const ::xla::Shape* shape_; +}; + +// A ShapePattern implementation that matches only if the shape is compatible to +// a Shape proto. +template +class ShapePatternCompatibleImpl { + public: + explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, + const ::xla::Shape* shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + } + + private: + Previous previous_; + const ::xla::Shape* shape_; +}; + +// A ShapePattern implementation that matches only if the shape has a given +// element type. +template +class ShapePatternElementTypeImpl { + public: + explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, + PrimitiveType element_type) + : previous_(previous), element_type_(element_type) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && shape->element_type() == element_type_; + } + + private: + Previous previous_; + PrimitiveType element_type_; +}; + +// A ShapePattern implementation that matches only if the shape is scalar. +template +class ShapePatternIsScalarImpl { + public: + explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape is an array +template +class ShapePatternIsArrayImpl { + public: + explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape is a tuple. +template +class ShapePatternIsTupleImpl { + public: + explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape has a given +// rank. +template +class ShapePatternRankImpl { + public: + explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) + : previous_(previous), rank_(rank) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + } + + private: + Previous previous_; + int64 rank_; +}; + +// A ShapePattern implementation that matches only if the shape has a layout +// that matches a given pattern. +template +class ShapePatternLayoutImpl { + public: + explicit constexpr ShapePatternLayoutImpl( + const Previous& previous, + const LayoutPattern& layout) + : previous_(previous), layout_(layout) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout()); + } + + bool Match(Shape* shape) const { + return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout()); + } + + private: + Previous previous_; + LayoutPattern layout_; +}; + +// A ShapePattern implementation that matches only if the shape has a subshape +// that matches a given pattern. +template +class ShapePatternSubshapeImpl { + public: + explicit ShapePatternSubshapeImpl( + const Previous& previous, ShapeIndexView index, + const ShapePattern& subshape) + : previous_(previous), index_(index), subshape_(subshape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + } + + bool Match(::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + } + + private: + Previous previous_; + ShapeIndexView index_; + ShapePattern subshape_; +}; + +// A pattern that matches Shapes. +template +class ShapePattern { + public: + explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) + : impl_(impl), matched_shape_(matched_shape) {} + + // Returns true and captures the shape iff it matches the pattern. + bool Match(const ::xla::Shape* shape) const { + if (impl_.Match(shape)) { + if (matched_shape_) { + *matched_shape_ = shape; + } + return true; + } + return false; + } + + // Returns true and captures the shape iff it matches the pattern. + bool Match(::xla::Shape* shape) const { + if (impl_.Match(shape)) { + if (matched_shape_) { + *matched_shape_ = shape; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the shape equals the given proto. + // The layout must outlive the returned pattern. + constexpr ShapePattern> EqualTo( + const ::xla::Shape* shape) const { + return ShapePattern>( + ShapePatternEqualImpl(impl_, shape), matched_shape_); + } + + // Modifies the pattern to match only if the shape is compatible to the given + // proto. The layout must outlive the returned pattern. + constexpr ShapePattern> + CompatibleTo(const ::xla::Shape* shape) const { + return ShapePattern>( + ShapePatternCompatibleImpl(impl_, shape), matched_shape_); + } + + // Modifies the pattern to match only if the shape has the given element type. + constexpr ShapePattern> + WithElementType(PrimitiveType element_type) const { + return ShapePattern>( + ShapePatternElementTypeImpl(impl_, element_type), matched_shape_); + } + + // Modifies the pattern to match only if the shape is scalar. + constexpr ShapePattern> IsScalar() + const { + return ShapePattern>( + ShapePatternIsScalarImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape is an array. + constexpr ShapePattern> IsArray() + const { + return ShapePattern>( + ShapePatternIsArrayImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape is a tuple. + constexpr ShapePattern> IsTuple() + const { + return ShapePattern>( + ShapePatternIsTupleImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape has the given rank. + constexpr ShapePattern> WithRank( + int64 rank) const { + return ShapePattern>( + ShapePatternRankImpl(impl_, rank), matched_shape_); + } + + // Modifies the pattern to match only if the shape has a layout that matches + // the given pattern. + template + constexpr ShapePattern> + WithLayout(const LayoutPattern& layout) const { + return ShapePattern>( + ShapePatternLayoutImpl(impl_, layout), + matched_shape_); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + WithLayoutEqualTo(const ::xla::Layout* layout) const { + return WithLayout(Layout().EqualTo(layout)); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + IsDenseArray(const ::xla::Layout* layout) const { + return WithLayout(Layout().WithDenseFormat()); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + IsSparseArray(const ::xla::Layout* layout) const { + return WithLayout(Layout().WithSparseFormat()); + } + + // Modifies the pattern to match only if the shape has a subshape that matches + // the given pattern. + template + ShapePattern> + WithSubshape(ShapeIndexView index, + const ShapePattern& subshape) const { + return ShapePattern< + ShapeType, ShapePatternSubshapeImpl>( + ShapePatternSubshapeImpl(impl_, index, + subshape), + matched_shape_); + } + + ShapePattern>> + WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { + return WithSubshape(index, + ShapePattern( + ShapePatternBaseImpl(), nullptr) + .EqualTo(shape)); + } + + ShapePattern>> + WithSubshapeCompatibleTo(ShapeIndexView index, + const ::xla::Shape* shape) const { + return WithSubshape(index, + ShapePattern( + ShapePatternBaseImpl(), nullptr) + .CompatibleTo(shape)); + } + + private: + Impl impl_; + ShapeType** matched_shape_; +}; + +} // namespace detail + +// Creates a shape pattern that will capture the matched layout in the argument. +inline constexpr detail::ShapePattern +Shape(const ::xla::Shape** matched_shape = nullptr) { + return detail::ShapePattern( + detail::ShapePatternBaseImpl(), matched_shape); +} + +// Creates a shape pattern that will capture the matched layout in the argument. +inline constexpr detail::ShapePattern<::xla::Shape, + detail::ShapePatternBaseImpl> +Shape(::xla::Shape** matched_shape) { + return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( + detail::ShapePatternBaseImpl(), matched_shape); +} + +namespace detail { + +template +class HloInstructionPattern; + +// The base HloInstructionPattern implementation. Matches only if the +// instruction is not nullptr. +class HloInstructionPatternBaseImpl { + public: + bool Match(const ::xla::HloInstruction* inst) const { + return inst != nullptr; + } +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a given name. +template +class HloInstructionPatternNameImpl { + public: + explicit HloInstructionPatternNameImpl(const Previous& previous, + tensorflow::StringPiece name) + : previous_(previous), name_(name) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->name() == name_; + } + + private: + Previous previous_; + tensorflow::StringPiece name_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a given opcode. +template +class HloInstructionPatternOpcodeImpl { + public: + explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, + HloOpcode opcode, + bool invert) + : previous_(previous), opcode_(opcode), invert_(invert) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + } + + private: + Previous previous_; + HloOpcode opcode_; + bool invert_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a shape that matches a given pattern. +template +class HloInstructionPatternShapeImpl { + public: + explicit constexpr HloInstructionPatternShapeImpl( + const Previous& previous, const ShapePattern& shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && shape_.Match(&inst->shape()); + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + } + + private: + Previous previous_; + ShapePattern shape_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has an operand that matches a given pattern. +template +class HloInstructionPatternOperandImpl { + public: + explicit constexpr HloInstructionPatternOperandImpl( + const Previous& previous, int64 operand_index, + const HloInstructionPattern& operand) + : previous_(previous), operand_index_(operand_index), operand_(operand) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_)); + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_)); + } + + private: + Previous previous_; + int64 operand_index_; + HloInstructionPattern operand_; +}; + +// A pattern that matches HloInstructions. +template +class HloInstructionPattern { + public: + explicit constexpr HloInstructionPattern(const Impl& impl, + HloInstructionType** matched_inst) + : impl_(impl), matched_inst_(matched_inst) {} + + // Returns true and captures the instruction iff it matches the pattern. + bool Match(const ::xla::HloInstruction* inst) const { + if (impl_.Match(inst)) { + if (matched_inst_) { + *matched_inst_ = inst; + } + return true; + } + return false; + } + + // Returns true and captures the instruction iff it matches the pattern. + bool Match(::xla::HloInstruction* inst) const { + if (impl_.Match(inst)) { + if (matched_inst_) { + *matched_inst_ = inst; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the instruction has the given name. + HloInstructionPattern> + WithName(tensorflow::StringPiece name) const { + return HloInstructionPattern>( + HloInstructionPatternNameImpl(impl_, name), matched_inst_); + } + + // Modifies the pattern to match only if the instruction has the given opcode. + constexpr HloInstructionPattern> + WithOpcode(HloOpcode opcode) const { + return HloInstructionPattern>( + HloInstructionPatternOpcodeImpl(impl_, opcode, false), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction does not have the + // given opcode. + constexpr HloInstructionPattern> + WithoutOpcode(HloOpcode opcode) const { + return HloInstructionPattern>( + HloInstructionPatternOpcodeImpl(impl_, opcode, true), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction is a constant. + constexpr HloInstructionPattern> + IsConstant() const { + return WithOpcode(HloOpcode::kConstant); + } + + // Modifies the pattern to match only if the instruction is not a constant. + constexpr HloInstructionPattern> + IsNonConstant() const { + return WithoutOpcode(HloOpcode::kConstant); + } + + // Modifies the pattern to match only if the instruction has a shape that + // matches the given pattern. + template + constexpr HloInstructionPattern< + HloInstructionType, + HloInstructionPatternShapeImpl> + WithShape(const ShapePattern& shape) const { + return HloInstructionPattern< + HloInstructionType, + HloInstructionPatternShapeImpl>( + HloInstructionPatternShapeImpl(impl_, + shape), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction has an operand that + // matches the given pattern. + template + constexpr HloInstructionPattern< + HloInstructionType, + HloInstructionPatternOperandImpl> + WithOperand( + int64 operand_index, + const HloInstructionPattern& operand) const { + return HloInstructionPattern< + HloInstructionType, + HloInstructionPatternOperandImpl>( + HloInstructionPatternOperandImpl( + impl_, operand_index, operand), + matched_inst_); + } + + private: + Impl impl_; + HloInstructionType** matched_inst_; +}; + +} // namespace detail + +// Creates an instruction pattern that will capture the matched instruction in +// the argument. +inline constexpr detail::HloInstructionPattern< + const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> +Op(const ::xla::HloInstruction** matched_inst = nullptr) { + return detail::HloInstructionPattern( + detail::HloInstructionPatternBaseImpl(), matched_inst); +} + +// Creates an instruction pattern that will capture the matched instruction in +// the argument. +inline constexpr detail::HloInstructionPattern< + ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> +Op(::xla::HloInstruction** matched_inst) { + return detail::HloInstructionPattern<::xla::HloInstruction, + detail::HloInstructionPatternBaseImpl>( + detail::HloInstructionPatternBaseImpl(), matched_inst); +} + +// Helpers for nullary instructions. +#define XLA_NULLOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst) \ + ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ + return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ + } +XLA_NULLOP_PATTERN(Constant) +XLA_NULLOP_PATTERN(Infeed) +XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Recv) +#undef XLA_NULLOP_PATTERN + +// Helpers for unary instructions. +#define XLA_UNOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Arg&& arg)->decltype( \ + Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } +XLA_UNOP_PATTERN(Abs) +XLA_UNOP_PATTERN(RoundNearestAfz) +XLA_UNOP_PATTERN(Bitcast) +XLA_UNOP_PATTERN(Broadcast) +XLA_UNOP_PATTERN(BroadcastDimOne) +XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Copy) +XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(Exp) +XLA_UNOP_PATTERN(Fft) +XLA_UNOP_PATTERN(Floor) +XLA_UNOP_PATTERN(Imag) +XLA_UNOP_PATTERN(IsFinite) +XLA_UNOP_PATTERN(Log) +XLA_UNOP_PATTERN(Not) +XLA_UNOP_PATTERN(Negate) +XLA_UNOP_PATTERN(Outfeed) +XLA_UNOP_PATTERN(Real) +XLA_UNOP_PATTERN(Reduce) +XLA_UNOP_PATTERN(ReducePrecision) +XLA_UNOP_PATTERN(Reshape) +XLA_UNOP_PATTERN(Reverse) +XLA_UNOP_PATTERN(Send) +XLA_UNOP_PATTERN(Sign) +XLA_UNOP_PATTERN(Sin) +XLA_UNOP_PATTERN(Sort) +XLA_UNOP_PATTERN(Tanh) +XLA_UNOP_PATTERN(Transpose) +#undef XLA_UNOP_PATTERN + +// Helpers for binary instructions. +#define XLA_BINOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } +XLA_BINOP_PATTERN(Add) +XLA_BINOP_PATTERN(Atan2) +XLA_BINOP_PATTERN(Divide) +XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Dot) +XLA_BINOP_PATTERN(Eq) +XLA_BINOP_PATTERN(Gather) +XLA_BINOP_PATTERN(Ge) +XLA_BINOP_PATTERN(Gt) +XLA_BINOP_PATTERN(Le) +XLA_BINOP_PATTERN(Lt) +XLA_BINOP_PATTERN(Maximum) +XLA_BINOP_PATTERN(Minimum) +XLA_BINOP_PATTERN(Multiply) +XLA_BINOP_PATTERN(Ne) +XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(Remainder) +XLA_BINOP_PATTERN(Subtract) +XLA_BINOP_PATTERN(And) +XLA_BINOP_PATTERN(Or) +XLA_BINOP_PATTERN(ShiftLeft) +XLA_BINOP_PATTERN(ShiftRightArithmetic) +XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_BINOP_PATTERN + +// Helpers for ternary instructions. +#define XLA_TERNOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ + ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ + Arg1&& arg1, Arg2&& arg2) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2)); \ + } +XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Select); +#undef XLA_TERNOP_PATTERN + +// Helpers for matching non-constant instructions. +inline auto NonConstant() -> decltype(Op().IsNonConstant()) { + return Op().IsNonConstant(); +} + +template +inline auto NonConstant(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsNonConstant()) { + return Op(matched_inst).IsNonConstant(); +} + +} // namespace match + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5291b1437afc67312382fe52bf9a66a1843b1b4c --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(PatternMatcherTest, AddOp) { + constexpr char kModuleStr[] = R"(HloModule two_plus_two_module + ENTRY %two_plus_two_computation () -> f32[] { + %two = f32[] constant(2) + ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + + const HloInstruction* matched_inst; + HloInstruction* matched_operand; + Shape* matched_shape; + Layout* matched_layout; + + ASSERT_TRUE(Match( + hlo_module->entry_computation()->root_instruction(), + match::Op(&matched_inst) + .WithName("two_plus_two") + .WithOpcode(HloOpcode::kAdd) + .WithShape( + match::Shape(&matched_shape) + .WithLayout(match::Layout(&matched_layout).WithDenseFormat())) + .WithOperand( + 0, + match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant)))); + ASSERT_NE(matched_inst, nullptr); + EXPECT_EQ(matched_inst->name(), "two_plus_two"); + EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd); + + EXPECT_TRUE(Match(hlo_module->entry_computation()->root_instruction(), + match::Add(match::Constant(), match::Constant()))); + + EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(), + match::Op().WithName("bad_name"))); + matched_inst = nullptr; + EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(), + match::Multiply(&matched_inst, match::Op(), match::Op()))); +} + +TEST(PatternMatcherTest, ScalarShape) { + auto scalar_shape = ShapeUtil::MakeShape(F32, {}); + Shape* matched_shape; + EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar())); + EXPECT_EQ(matched_shape, &scalar_shape); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray())); + EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0))); + EXPECT_FALSE(Match( + &scalar_shape, + match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32))); +} + +TEST(PatternMatcherTest, ArrayShape) { + auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4}); + Shape* matched_shape; + EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); + EXPECT_EQ(matched_shape, &array_shape); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3))); + EXPECT_FALSE( + Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); + Layout* matched_layout; + EXPECT_FALSE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithSparseFormat()))); +} + +TEST(PatternMatcherTest, TupleShape) { + auto tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1, 2, 3}), + ShapeUtil::MakeShape(S32, {4, 5}), + }); + EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple())); + EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray())); + EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar())); + + Shape* subshape; + ASSERT_TRUE(Match( + &tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3)))); + ASSERT_NE(subshape, nullptr); + EXPECT_TRUE( + ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0}))); + EXPECT_TRUE(Match(&tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {0}))))); + EXPECT_FALSE(Match(&tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {1}))))); + + ASSERT_TRUE(Match( + &tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2)))); + ASSERT_NE(subshape, nullptr); + EXPECT_TRUE( + ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1}))); + EXPECT_TRUE(Match(&tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {1}))))); + EXPECT_FALSE(Match(&tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {0}))))); + + EXPECT_FALSE( + Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape()))); + EXPECT_FALSE( + Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index e62bafc50b0e1270702621c9ea7b2ee43e001fe0..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -53,8 +53,8 @@ bool IsReshapeOrTranspose(const HloInstruction* instruction) { instruction->opcode() == HloOpcode::kTranspose; } -// Returns true iff `instruction` can change its shape simply by adjusting -// metadata. +// Returns true if `instruction` can change its shape simply by adjusting +// metadata or if `instruction` is a broadcast of a scalar value. bool CanTriviallyChangeShape(const HloInstruction* instruction) { // NOTE: Technically a sequence of reshape(reshape(constant)) is also // trivially reshapable, so we might be tempted to simply recurse if @@ -88,19 +88,31 @@ bool CanTriviallyChangeShape(const HloInstruction* instruction) { instruction->user_count() == 1) { return true; } + + // A broadcase of scalar can trivially change its shape. + if (instruction->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(instruction->operand(0)->shape())) { + return true; + } + return false; } -// Finds the first non-scalar operand of an instruction that is a non-trivial -// reshape or transpose. Returns the operand if it is found or nullptr if not -// found. +// Returns true iff `instruction` is a reshape/transpose instruction for which +// a shape change is nontrivial. +bool IsNontrivialReshape(const HloInstruction* instruction) { + return !ShapeUtil::IsScalar(instruction->shape()) && + IsReshapeOrTranspose(instruction) && + !CanTriviallyChangeShape(instruction->operand(0)); +} + +// Finds the first operand of an instruction that is a non-trivial reshape or +// transpose. Returns such an operand or nullptr if not found. HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { - if (!ShapeUtil::IsScalar(operand->shape()) && - IsReshapeOrTranspose(operand) && - !CanTriviallyChangeShape(operand->operand(0))) { - VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " + if (IsNontrivialReshape(operand)) { + VLOG(5) << "Found first non-trivial reshape operand of " << hlo->ToString(HloPrintOptions().set_print_metadata(false)) << ":\n\t" << operand->ToString(HloPrintOptions().set_print_metadata(false)); @@ -110,7 +122,7 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( return nullptr; } -// Returns whether `a` and `b` are equivalent for the purposes of this pass. +// Returns whether `a` and `b` are equivalent reshapes/transposes. bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { if (a->opcode() != b->opcode() || !ShapeUtil::SameDimensions(a->shape(), b->shape())) { @@ -127,71 +139,14 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if all operands of `instruction` can easily change shape. -// Operands can easily change shape if they are all reshapes/transposes to and -// from the same shape. Additionally, operands like constant, rng, and any -// scalar change shape with only an adjustment of metadata. -bool AllOperandsHaveEasyShapeChanges( - const HloInstruction* instruction, - const HloInstruction* first_reshape_operand) { - auto print_no_metadata = HloPrintOptions().set_print_metadata(false); - VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - // Check whether all operands: - // 0. Have the same dimensions as the output -- if not, it may be - // implicitly broadcast, which can confound the movement's - // correctness. - // - // And one of the following: - // 1. Are reshapes or transposes that have the same input and - // output shapes as all other reshaped or transposed operands. - // or - // 2. Are one of kConstant, kRng, and scalars that can change shape - // trivially, - for (const HloInstruction* operand : instruction->operands()) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToString(print_no_metadata) << "\n\tinstruction: " - << instruction->ToString(print_no_metadata); - return false; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToString(print_no_metadata) - << "\n\toperand: " << operand->ToString(print_no_metadata); - continue; - } - - if (CanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToString(print_no_metadata); - continue; - } - - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " - << operand->ToString(print_no_metadata); - return false; - } - - VLOG(3) << "All operands have easy shape changes: " - << instruction->ToString(print_no_metadata); - return true; -} - // This function is called once we've decided to sink reshape/transpose operands // across an instruction. It returns an updated `operand` with a shape that // plays nicely with `new_operand_shape`; either it has the same shape (of the // correct type), or it is a scalar that may be implicitly broadcast. -HloInstruction* UpdateOperand(HloComputation* computation, - const HloInstruction* first_reshape_operand, +HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, const Shape& new_operand_shape, HloInstruction* operand) { + HloComputation* computation = operand->parent(); const PrimitiveType element_type = operand->shape().element_type(); const Shape new_shape = ShapeUtil::ChangeElementType(new_operand_shape, element_type); @@ -222,36 +177,24 @@ HloInstruction* UpdateOperand(HloComputation* computation, VLOG(5) << "Using existing operand of kReshape or kTranspose"; return operand->mutable_operand(0); } + case HloOpcode::kBroadcast: { + CHECK(ShapeUtil::IsScalar(operand->operand(0)->shape())); + HloInstruction* inst = computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + VLOG(5) << "Changing broadcast from " << operand->ToString() << " to " + << inst->ToString(); + return inst; + } + default: LOG(FATAL) << "Unexpected operand opcode during update: " << operand; } } -// Try to sink any reshape or transpose operands of `instruction` across it. We -// do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. -StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, - HloInstruction* instruction) { - // Only perform sinks for live elementwise instructions with operands. - const bool is_dead = instruction->user_count() == 0 && - instruction != computation->root_instruction(); - if (!instruction->IsElementwise() || instruction->operands().empty() || - is_dead) { - return false; - } - - // Only perform sinks if there are any nontrivial reshape/transpose operands. - const HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - if (!first_reshape_operand) { - return false; - } - - // Only perform sinks if all operands can easily change shape. - if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { - return false; - } - +// Actually performs the reshape-move transformation -- that is, sinks the +// reshape or transpose operands of `instruction` across it. +StatusOr PerformSinkReshapeOrTranspose( + HloInstruction* instruction, const HloInstruction* first_reshape_operand) { auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); @@ -272,8 +215,8 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, } VLOG(3) << "Updating operand #" << i << ": " << operands[i]->ToString(print_no_metadata); - operands[i] = UpdateOperand(computation, first_reshape_operand, - new_operand_shape, operands[i]); + operands[i] = + UpdateOperand(first_reshape_operand, new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is @@ -285,6 +228,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, *shape->mutable_layout() = new_operand_shape.layout(); } } + HloComputation* computation = instruction->parent(); HloInstruction* new_elementwise = computation->AddInstruction(instruction->CloneWithNewOperands( // `instruction` may change the element type, e.g., from @@ -319,6 +263,141 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, return true; } +// Returns true if the instruction is a reshape-move candidate. +// +// An instruction is a reshape-move candidate if the instruction is elementwise, +// has at least one nontrivial reshape/transpose operand, and its operands are +// either trivially reshapable or are equivalent nontrivial reshapes/transposes. +bool IsReshapeMoveCandidate(HloInstruction* instruction) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + VLOG(5) << "** Checking instruction: " + << instruction->ToString(print_no_metadata); + + // Only perform reshape-move for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != instruction->parent()->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { + return false; + } + + // Check whether all operands: + // 0. Have the same dimensions as the output -- if not, they may be + // implicitly broadcast, which can confound the movement's + // correctness. + // + // And one of the following: + // 1. Are reshapes or transposes that have the same input and + // output shapes as all other reshaped or transposed operands. + // or + // 2. Are one of kConstant, kRng, broadcast of a scalar value, and scalars + // that can change shape trivially. + const HloInstruction* first_reshape_operand = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); + return false; + } + + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " + << operand->ToString(print_no_metadata); + continue; + } + + if (!IsNontrivialReshape(operand)) { + VLOG(5) << "Operand can't trivially change shape: " + << operand->ToString(print_no_metadata); + return false; + } + + if (first_reshape_operand == nullptr) { + first_reshape_operand = operand; + VLOG(5) << "First reshape operand " + << operand->ToString(print_no_metadata); + } else if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) + << "Operand is an equivalent reshape of the first reshape operand " + << operand->ToString(print_no_metadata); + } else { + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is a reshape but is not equivalent to the first " + "Reshape operand" + << operand->ToString(print_no_metadata); + return false; + } + } + + if (first_reshape_operand) { + VLOG(5) << "All operands have easy shape changes: " + << instruction->ToString(print_no_metadata); + } + + return first_reshape_operand != nullptr; +} + +// Reshape-moves all qualifying instructions in reshape_candidates. Returns +// true if it makes changes. +// +// `reshape_candidates` is a set of HloInstructions with nontrivial reshape +// operands, and a instruction in the set can be reshape-moved iff all the users +// of its nontrivial reshape operands can also be reshaped-moved. +// +// The algorithm here iteratively finds the nontrivial operands with users that +// are outside the set of `reshape_candidates`, and removes their users from +// `reshape_candidates`, until either `reshape_candidates` becomes empty or none +// of the remaining nontrivial operands have users outside `reshape_candidates`. +// In the later case, all the remaining instructions in `reshape_candidates` +// are reshape-moved and the routine returns true. +StatusOr TryReshapeMoveOnCandidates( + HloInstructionSet* reshape_candidates) { + bool removed = true; + while (!reshape_candidates->empty() && removed) { + if (VLOG_IS_ON(5)) { + for (const HloInstruction* instruction : *reshape_candidates) { + VLOG(5) << "candidate " << instruction->ToString(); + } + } + ConstHloInstructionSet nontrivial_operands; + for (const HloInstruction* instruction : *reshape_candidates) { + for (const auto* operand : instruction->operands()) { + if (IsNontrivialReshape(operand)) { + nontrivial_operands.insert(operand); + } + } + } + + removed = false; + for (auto operand : nontrivial_operands) { + if (c_any_of(operand->users(), [&](HloInstruction* user) { + return !reshape_candidates->count(user); + })) { + for (auto* user : operand->users()) { + removed |= reshape_candidates->erase(user) > 0; + } + } + } + } + + if (reshape_candidates->empty()) { + return false; + } + for (HloInstruction* instruction : *reshape_candidates) { + const HloInstruction* first_reshape_operand = + FirstNonScalarAndNonTrivialReshapeOperand(instruction); + TF_ASSIGN_OR_RETURN( + bool did_change, + PerformSinkReshapeOrTranspose(instruction, first_reshape_operand)); + CHECK(did_change); + } + return true; +} + } // namespace StatusOr ReshapeMover::Run(HloModule* module) { @@ -326,11 +405,15 @@ StatusOr ReshapeMover::Run(HloModule* module) { VLOG(2) << "Pre ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); for (auto* comp : module->MakeNonfusionComputations()) { - for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool did_change, - TrySinkReshapeOrTranspose(comp, instruction)); - changed |= did_change; + HloInstructionSet reshape_candidates; + for (HloInstruction* instruction : comp->instructions()) { + if (IsReshapeMoveCandidate(instruction)) { + reshape_candidates.insert(instruction); + } } + TF_ASSIGN_OR_RETURN(bool did_change, + TryReshapeMoveOnCandidates(&reshape_candidates)); + changed |= did_change; } VLOG(2) << "Post ReshapeMover HLO:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index aac8638a54f744f0c230ec6c5ca071c1daf45ab2..094f7319f462a71f4bfe972771a1de4aedbb8ee3 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -560,5 +560,95 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); } +TEST_F(ReshapeMoverTest, SinkTransposeAcrossBroadcastScalar) { + const string hlo_string = R"( + HloModule TransposeMulInversedTransposeModule + ENTRY TransposeMulInversedTranspose { + src0 = f32[20,8]{1,0} parameter(0) + transpose0 = f32[8,20]{1,0} transpose(src0), dimensions={1,0} + src1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(src1), dimensions={} + ROOT multiply0 = f32[8,20]{1,0} multiply(transpose0, broadcast0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Transpose(op::Multiply())); +} + +TEST_F(ReshapeMoverTest, ReshapeWithUsersOutsideCandidatesNotSink) { + const string hlo_string = R"( + HloModule ReshapeWithUsersOutsideCandidates + ENTRY ReshapeWithMultipleUsers { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + param4 = f32[8,20]{1,0} parameter(4) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, param4) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_FALSE(changed); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink1) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates1 + ENTRY ReshapeWithMultipleUsers1 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + param1 = f32[] parameter(1) + broadcast0 = f32[8,20]{1,0} broadcast(param1), dimensions={} + param2 = f32[20,8]{1,0} parameter(2) + reshape1 = f32[8,20]{1,0} reshape(param2) + param3 = f32[20,8]{1,0} parameter(3) + reshape2 = f32[8,20]{1,0} reshape(param3) + add0 = f32[8,20]{1,0} add(reshape0, broadcast0) + add1 = f32[8,20]{1,0} add(reshape0, reshape1) + add2 = f32[8,20]{1,0} add(reshape1, reshape2) + ROOT tuple = (f32[8,20]{1,0},f32[8,20]{1,0}, + f32[8,20]{1,0}) tuple(add0, add1, add2) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Tuple(op::Reshape(), op::Reshape(), op::Reshape())); +} + +TEST_F(ReshapeMoverTest, ReshapeNoUsersOutsideCandidatesSink2) { + const string hlo_string = R"( + HloModule ReshapeNoUsersOutsideCandidates2 + ENTRY ReshapeWithMultipleUsers2 { + param0 = f32[20,8]{1,0} parameter(0) + reshape0 = f32[8,20]{1,0} reshape(param0) + ROOT add0 = f32[8,20]{1,0} add(reshape0, reshape0) + } + )"; + + ParseAndVerifyModule(hlo_string); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ReshapeMover().Run(&module())); + EXPECT_TRUE(changed); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Reshape(op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 98dfc89867ab33788c4cc837a66d6751a1ef2507..52500e4e79042c51d4bea17dea6845ed23433d6c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -231,10 +232,14 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr>> +Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - int device_ordinal) { - std::vector shaped_buffers; + tensorflow::gtl::ArraySlice + stream_executors) { + CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); + std::vector> replicated_arguments; + replicated_arguments.resize(options_.number_of_replicas()); for (size_t i = 0; i < arguments.size(); ++i) { auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); if (!buffer_status.ok()) { @@ -242,29 +247,32 @@ StatusOr> Service::ResolveAndValidateArguments( StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); - - // Verify allocation is same platform and device as the execution. - if (shaped_buffer->platform() != execute_backend_->platform() || - shaped_buffer->device_ordinal() != device_ordinal) { - return InvalidArgument( - "argument %lu is on device %s:%d but computation will be executed " - "on device %s", - i, shaped_buffer->platform()->Name().c_str(), - shaped_buffer->device_ordinal(), - execute_backend_->device_name(device_ordinal).c_str()); + auto replicated_buffers = buffer_status.ValueOrDie(); + CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size()); + for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { + const ShapedBuffer* shaped_buffer = replicated_buffers[replica]; + int replica_device_ordinal = stream_executors[replica]->device_ordinal(); + // Verify allocation is same platform and device as the execution. + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != replica_device_ordinal) { + return InvalidArgument( + "argument %lu is on device %s:%d but computation will be executed " + "on device %s", + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(replica_device_ordinal).c_str()); + } + replicated_arguments[replica].push_back(shaped_buffer); } - - shaped_buffers.push_back(shaped_buffer); } - return shaped_buffers; + return replicated_arguments; } StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -278,8 +286,15 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { + if (user_computation == nullptr) { + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", + i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), + ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + } return InvalidParameterArgument( - *user_computation.ParameterMetadata(i).value(), + *user_computation->ParameterMetadata(i).value(), "Argument does not match shape of computation parameter %d: want %s, " "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), @@ -306,8 +321,6 @@ StatusOr> Service::CreateModuleConfig( if (execution_options != nullptr) { config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); - config->enable_hlo_profiling( - execution_options->debug_options().xla_hlo_profile()); } else { config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); } @@ -324,7 +337,7 @@ StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation) { + const UserComputation* user_computation) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); @@ -396,6 +409,37 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } +StatusOr>> Service::BuildExecutables( + const std::vector& module_protos, + std::vector> module_configs, + Backend* backend, + std::vector> executors, + DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf("BuildExecutable on service %p", this); + + VLOG(1) << "Computations:"; + for (const HloModuleProto* proto : module_protos) { + VLOG(1) << proto->name(); + } + + CHECK_EQ(module_protos.size(), module_configs.size()); + std::vector> modules; + for (int64 i = 0; i < module_protos.size(); ++i) { + const HloModuleProto* proto = module_protos[i]; + const HloModuleConfig& config = *module_configs[i]; + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(*proto, config)); + modules.push_back(std::move(module)); + } + + TF_ASSIGN_OR_RETURN( + std::vector> executables, + backend->compiler()->Compile(std::move(modules), std::move(executors), + device_allocator)); + + return std::move(executables); +} + StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, Backend* backend, @@ -489,7 +533,8 @@ StatusOr> Service::BuildAndCacheExecutable( StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice> arguments, + tensorflow::gtl::ArraySlice>> + arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { @@ -512,6 +557,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + std::vector> result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(replicas[replica])); @@ -544,23 +591,20 @@ Service::ExecuteParallelAndRegisterResult( backend->StreamBorrower()); // Asynchronously launch the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr result, - executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + executables[i]->ExecuteAsyncOnStream( + &run_options, arguments[i][replica])); if (replica == 0 && profile != nullptr) { streams.back()->ThenStopTimer(timers.back().get()); } - // All replicas share the same device address for the result allocation, - // so only one of the replicas need to register the result handle. - if (replica == 0) { - TF_ASSIGN_OR_RETURN( - GlobalDataHandle handle, - allocation_tracker_.Register(std::move(result), result_tags[i])); - result_handles.push_back(handle); - } + result_buffers.emplace_back(std::move(result)); } + TF_ASSIGN_OR_RETURN(GlobalDataHandle handle, + allocation_tracker_.RegisterReplicatedBuffers( + std::move(result_buffers), result_tags[i])); + result_handles.push_back(handle); } // Wait for all executions to complete. @@ -626,9 +670,9 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - const string& result_tag, ExecutionProfile* profile) { + const tensorflow::gtl::ArraySlice> + arguments, + Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector::SmartPtr> streams; @@ -661,21 +705,26 @@ StatusOr Service::ExecuteAndRegisterResult( backend->inter_op_thread_pool()); } - std::unique_ptr result; if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN(result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); - } else { - // TODO(b/69985541): Support profiling also on this path. - std::vector> - repeated_arguments(options_.number_of_replicas(), arguments); + TF_ASSIGN_OR_RETURN( + auto result, executable->ExecuteOnStreamWrapper(&run_options[0], + profile, arguments[0])); + return allocation_tracker_.Register(std::move(result), result_tag); + } - TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( - run_options, repeated_arguments)); - TF_RET_CHECK(!results.empty()); - result = std::move(results[0]); + // TODO(b/69985541): Support profiling also on this path. + + std::vector> + replicated_arguments; + for (const auto& arg : arguments) { + replicated_arguments.emplace_back(arg); } - return allocation_tracker_.Register(std::move(result), result_tag); + + TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( + run_options, replicated_arguments)); + TF_RET_CHECK(!results.empty()); + return allocation_tracker_.RegisterReplicatedBuffers(std::move(results), + result_tag); } tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, @@ -685,11 +734,52 @@ tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, return computation->SetReturnValue(arg->operand()); } +StatusOr> +Service::GetExecutors(const ExecutionOptions& execution_options, + int64 requests_size, int64 request_index) const { + if (execution_options.device_handles().empty()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); + } + if (requests_size > 1 && execution_options.device_handles_size() > 1) { + return InvalidArgument( + "Parallel requests with multiple device handles is not supported. " + "Found %lld parallel requests, with request %lld containing %d device " + "handles.", + requests_size, request_index, execution_options.device_handles_size()); + } + std::vector executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } + return executors; +} + +StatusOr>> Service::GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice arguments) { + // Resolve the allocations for the arguments of the computation, and create + // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, execution_options.device_handles(0))); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arguments, replicas)); + return replicated_arguments; +} + tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - std::vector> all_arguments; + std::vector>> all_arguments; std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; @@ -713,18 +803,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // is one of the executors to run the replicated computation. const ExecutionOptions& execution_options = arg->requests(i).execution_options(); - if (execution_options.device_handles().empty()) { - return FailedPrecondition( - "device handles must be given to execute parallel computations"); - } - std::vector executors; - for (const auto& device_handle : execution_options.device_handles()) { - TF_ASSIGN_OR_RETURN(auto replicas, - Replicas(*execute_backend_, device_handle)); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); - executors.push_back(executor); - } + + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); // Resolve the UserComputation object associated with the requested // computation and compute the program shape. @@ -741,27 +823,24 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); - // Resolve the allocations for the arguments of the computation, and create - // a vector of device memory offsets for the arguments from the allocations. - // In the case of partitioned computations, assume all arguments go on the - // zeroth core. - TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(request.arguments(), - executors[0]->device_ordinal())); + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. + // the program and the argument allocations. Here, we care only about the + // shapes of the arguments, so, it is sufficient to use the arguments of + // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, - request.execution_options(), *user_computation)); + CreateModuleConfig(*program_shape, replicated_arguments.front(), + request.execution_options(), user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {}); + all_arguments.push_back(replicated_arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); versioned_handles.push_back(versioned_handle); module_configs.push_back(std::move(module_config)); computation_names.insert(computation_names.end(), executors.size(), @@ -807,6 +886,107 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteGraphParallel( + const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { + VLOG(1) << "running execute-graph-parallel request"; + + std::vector>> all_arguments; + std::vector> all_executors; + std::vector module_protos; + std::vector> module_configs; + std::vector computation_names; + std::vector device_handles; + + int num_requested_devices = + std::accumulate(arg->requests().begin(), arg->requests().end(), 0, + [](int a, const ExecuteGraphRequest& r) -> int { + return a + r.execution_options().device_handles_size(); + }); + if (num_requested_devices * options_.number_of_replicas() > + execute_backend_->device_count()) { + return FailedPrecondition( + "there are not enough stream executors to execute %d computations", + num_requested_devices); + } + + for (int64 i = 0; i < arg->requests_size(); ++i) { + // Get the stream executor for the i'th computation. This stream executor + // is one of the executors to run the replicated computation. + const ExecutionOptions& execution_options = + arg->requests(i).execution_options(); + const ExecuteGraphRequest& request = arg->requests(i); + TF_RET_CHECK(request.has_computation()) << "computations may not be empty"; + TF_RET_CHECK(request.computation().has_program_shape()) + << "programe shape may not be empty"; + + // Get the executors. + TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, + arg->requests_size(), i)); + + // Get the replicated arguments. + TF_ASSIGN_OR_RETURN(auto replicated_arguments, + GetArguments(execution_options, request.arguments())); + + // Create an HloModuleConfig object for the computation, given the shape of + // the program and the argument allocations. Here, we care only about the + // shapes of the arguments, so, it is sufficient to use the arguments of + // replica 0. + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(request.computation().program_shape(), + replicated_arguments.front(), + request.execution_options(), + /*user_computation=*/nullptr)); + VLOG(3) + << "ExecuteGraphParallel created HloModuleConfig computation layout: " + << module_config->entry_computation_layout().ToString(); + + // Adds to the vectors to build and execute the computations after the loop. + all_arguments.push_back(replicated_arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); + module_protos.push_back(&request.computation()); + module_configs.push_back(std::move(module_config)); + computation_names.insert(computation_names.end(), executors.size(), + request.computation().name()); + all_executors.push_back(executors); + device_handles.insert(device_handles.end(), + execution_options.device_handles().begin(), + execution_options.device_handles().end()); + } + + // Build the HloModules and compile to generate the executables. + // + // TODO(jlebar): There's currently no way to pass a device allocator to + // ExecuteGraphParallel, so we have to pass a null device_allocator below. + TF_ASSIGN_OR_RETURN(std::vector> executables, + BuildExecutables(module_protos, std::move(module_configs), + execute_backend_.get(), all_executors, + /*device_allocator=*/nullptr)); + std::vector executable_ptrs; + executable_ptrs.reserve(executables.size()); + for (const auto& executable : executables) { + executable_ptrs.push_back(executable.get()); + } + + // Execute the generated executables in parallel and return the device + // handles for each computation's output. + ExecutionProfile profile; + TF_ASSIGN_OR_RETURN( + std::vector outputs, + ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, + execute_backend_.get(), device_handles, + computation_names, &profile)); + for (const GlobalDataHandle& output : outputs) { + ExecuteResponse response; + *response.mutable_output() = output; + *response.mutable_profile() = profile; + *result->add_responses() = response; + } + + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; + return tensorflow::Status::OK(); +} + tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); @@ -831,6 +1011,47 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return tensorflow::Status::OK(); } +tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} + +tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + ExecuteGraphParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteGraphParallel(¶llel_arg, ¶llel_result)); + return PickParallelResponse(parallel_result, result); +} + +tensorflow::Status Service::PickParallelResponse( + const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { + // The "result device" selection is a bit hacky, but better than assuming it + // is device 0. We have b/76035356 for restructuring the client API to clean + // up the current asymmetries and support more functionalities. + for (int64 i = 0; i < parallel_result.responses_size(); ++i) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, + allocation_tracker_.ResolveForReplica( + parallel_result.responses(i).output(), 0)); + const Shape& shape = buffer->on_host_shape(); + if (!ShapeUtil::IsEmptyTuple(shape)) { + *result = parallel_result.responses(i); + VLOG(3) << "Fetching result from device " << i << ": " + << ShapeUtil::HumanString(shape); + return Status::OK(); + } + } + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + VLOG(1) << "Defaulting to device 0 result"; + return Status::OK(); +} + tensorflow::Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); @@ -847,28 +1068,25 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, // If we received multiple device handles, we must partition the module. if (arg->execution_options().device_handles_size() > 1) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - TF_RET_CHECK(parallel_result.responses_size() > 0); - *result = parallel_result.responses(0); - return Status::OK(); + return ExecuteOneToN(arg, result); } TF_ASSIGN_OR_RETURN( std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(arg->arguments(), - execute_backend_->default_device_ordinal())); + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + // Since we care only about the shapes of the arguments, it is sufficient to + // use the arguments of replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options(), - *user_computation)); + CreateModuleConfig(*program_shape, replicated_arguments.front(), + arg->execution_options(), user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -884,20 +1102,21 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, executable->session_module()->set_execution_platform( execute_backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments( - arguments, execute_backend_->default_stream_executor(), + replicated_arguments.front(), + execute_backend_->default_stream_executor(), execute_backend_->transfer_manager(), executable->session_module())); } TF_ASSIGN_OR_RETURN( *result->mutable_output(), ExecuteAndRegisterResult( - executable.get(), arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), + executable.get(), replicated_arguments, execute_backend_.get(), "result of " + user_computation->name(), result->mutable_profile())); if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, - allocation_tracker_.Resolve(result->output())); + TF_ASSIGN_OR_RETURN( + const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(result->output(), 0)); TF_RETURN_IF_ERROR(RecordResult( *result_buffer, execute_backend_->default_stream_executor(), execute_backend_->transfer_manager(), executable->session_module())); @@ -908,6 +1127,74 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return tensorflow::Status::OK(); } +StatusOr> Service::BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { + VLOG(1) << Printf( + "BuildExecutable on service %p with serialized module proto: %s", this, + module_proto.name().c_str()); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(module_proto, *module_config)); + + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); + + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); + + return std::move(executable); +} + +tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + VLOG(1) << "running execute-graph request"; + + if (!arg->has_computation()) { + return InvalidArgument("computations may not be empty"); + } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("programe shape may not be empty"); + } + + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + return ExecuteOneToN(arg, result); + } + + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_ASSIGN_OR_RETURN( + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, + CreateModuleConfig(arg->computation().program_shape(), + replicated_arguments.front(), + arg->execution_options())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + BuildExecutable(arg->computation(), std::move(module_config), + execute_backend_.get(), + execute_backend_->default_stream_executor(), + /*device_allocator=*/nullptr)); + + TF_ASSIGN_OR_RETURN( + *result->mutable_output(), + ExecuteAndRegisterResult( + executable.get(), replicated_arguments, execute_backend_.get(), + "result of " + arg->computation().name(), result->mutable_profile())); + + VLOG(1) << "successfully completed 'execute-graph' request"; + return tensorflow::Status::OK(); +} + tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) { VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); @@ -925,15 +1212,17 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); TF_ASSIGN_OR_RETURN( - std::vector arguments, - ResolveAndValidateArguments(arg->arguments(), - execute_backend_->default_device_ordinal())); + std::vector> replicated_arguments, + ResolveAndValidateArguments(arg->arguments(), replicas)); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arguments, arg->execution_options(), - *user_computation)); + CreateModuleConfig(*program_shape, replicated_arguments.front(), + arg->execution_options(), user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -946,21 +1235,17 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, versioned_handle, std::move(module_config), execute_backend_.get(), execute_backend_->default_stream_executor(), &profile)); - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - // Set up streams. std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, execute_backend_->BorrowStream(executor)); streams.push_back(std::move(stream)); } - std::unique_ptr result_buffer; - for (const Pool::SmartPtr& stream : streams) { + std::vector> result_buffers; + for (size_t i = 0; i < streams.size(); ++i) { + const auto& stream = streams[i]; ExecutableRunOptions options; options.set_stream(stream.get()); options.set_allocator(execute_backend_->memory_allocator()); @@ -971,20 +1256,17 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, ServiceExecutableRunOptions service_options( options, execute_backend_->StreamBorrower()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr this_result_buffer, - executable->ExecuteAsyncOnStream(&service_options, arguments)); + TF_ASSIGN_OR_RETURN(std::unique_ptr this_result_buffer, + executable->ExecuteAsyncOnStream( + &service_options, replicated_arguments[i])); - // Take the first result. - if (result_buffer == nullptr) { - result_buffer = std::move(this_result_buffer); - } + result_buffers.emplace_back(std::move(this_result_buffer)); } TF_ASSIGN_OR_RETURN( GlobalDataHandle output, - allocation_tracker_.Register(std::move(result_buffer), - "result of " + user_computation->name())); + allocation_tracker_.RegisterReplicatedBuffers( + std::move(result_buffers), "result of " + user_computation->name())); *result->mutable_execution() = execution_tracker_.Register( execute_backend_.get(), std::move(streams), profile, output); @@ -1012,7 +1294,7 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, - allocation_tracker_.Resolve(arg->data())); + allocation_tracker_.ResolveForReplica(arg->data(), 0)); const Shape* return_shape; if (arg->has_shape_with_layout()) { @@ -1073,37 +1355,24 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // All memory allocation is done on the first replica. The allocations in all - // other replicas mirror the firsts'. - int master_device_ordinal = replicas[0]->device_ordinal(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr shaped_buffer, - execute_backend_->transfer_manager()->AllocateShapedBuffer( - shape, execute_backend_->memory_allocator(), master_device_ordinal)); - - // Transfer the data to the replicas. + // Allocate memory in each replica and transfer the data to all replicas. + std::vector> replicated_buffers; for (se::StreamExecutor* executor : replicas) { - if (executor->device_ordinal() == master_device_ordinal) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *shaped_buffer)); - } else { - // The replica is not the master. Create an cloned shaped buffer with - // the replica's device ordinal. This is required because - // TransferLiteralToDevice verifies that the device ordinal of the shaped - // buffer matches that of the executor. - std::unique_ptr clone = - CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *clone)); - } + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + execute_backend_->transfer_manager()->AllocateShapedBuffer( + shape, execute_backend_->memory_allocator(), + executor->device_ordinal())); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, *shaped_buffer)); + replicated_buffers.emplace_back(std::move(shaped_buffer)); } - TF_ASSIGN_OR_RETURN( - *result->mutable_data(), - allocation_tracker_.Register(std::move(shaped_buffer), - StrCat("TransferToServer literal of shape ", - ShapeUtil::HumanString(shape)))); + TF_ASSIGN_OR_RETURN(*result->mutable_data(), + allocation_tracker_.RegisterReplicatedBuffers( + std::move(replicated_buffers), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); return tensorflow::Status::OK(); } @@ -1254,7 +1523,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, CreateModuleConfig(program_shape, {}, execution_options, - *user_computation)); + user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( @@ -1275,6 +1544,50 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. + // + // TODO(b/77824332): Make HloEvaluator take care of the re-layout. + if (arg->has_output_layout()) { + result_literal = result_literal->Relayout(arg->output_layout()); + } + *result->mutable_literal() = result_literal->ToProto(); + + return tensorflow::Status::OK(); +} + +tensorflow::Status Service::ComputeConstantGraph( + const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { + if (!arg->has_computation()) { + return InvalidArgument("computations may not be empty"); + } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("program shape may not be empty"); + } + if (arg->computation().program_shape().parameters_size() != 0) { + return InvalidArgument( + "constant computation may not depend on any parameters."); + } + + ProgramShape program_shape = arg->computation().program_shape(); + TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); + if (arg->has_output_layout()) { + TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( + arg->output_layout(), program_shape.result())); + } + + HloModuleConfig config(program_shape); + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(arg->computation(), config)); + + HloEvaluator evaluator; + TF_ASSIGN_OR_RETURN(auto result_literal, + evaluator.Evaluate>( + *module, /*arg_literals=*/{})); + + // Since the result layout is non-effective to the Evaluator results, explicit + // relayout here. + // + // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { result_literal = result_literal->Relayout(arg->output_layout()); } @@ -1286,7 +1599,7 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, tensorflow::Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, - allocation_tracker_.Resolve(arg->data())); + allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); return tensorflow::Status::OK(); } @@ -1346,6 +1659,36 @@ tensorflow::Status Service::GetComputationStats( return tensorflow::Status::OK(); } +tensorflow::Status Service::GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { + if (!arg->has_computation()) { + return InvalidArgument("Computations may not be empty."); + } + if (!arg->computation().has_program_shape()) { + return InvalidArgument("Program shape may not be empty."); + } + + HloModuleConfig config(arg->computation().program_shape()); + config.set_debug_options(arg->debug_options()); + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + HloModule::CreateFromProto(arg->computation(), config)); + + hlo_graph_dumper::MaybeDumpHloModule(*module, + "computation statistics subject"); + + // Run HLO analysis to get the computation statistics. + HloCostAnalysis analysis( + execute_backend_->compiler()->ShapeSizeBytesFunction()); + + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); + + ComputationStats stats; + stats.set_flop_count(analysis.flop_count()); + stats.set_transcendental_count(analysis.transcendental_count()); + *result->mutable_stats() = stats; + return tensorflow::Status::OK(); +} + template tensorflow::Status Service::AddInstruction( const RequestT* arg, ResponseT* result, @@ -1445,6 +1788,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { case OpRequest::kFftRequest: handle_status = computation->AddFftInstruction(arg->fft_request()); break; + case OpRequest::kGatherRequest: + handle_status = computation->AddGatherInstruction(arg->gather_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); @@ -1456,6 +1802,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddOutfeedInstruction(arg->outfeed_request()); break; + case OpRequest::kHostComputeRequest: + handle_status = + computation->AddHostComputeInstruction(arg->host_compute_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, @@ -1548,8 +1898,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { case OpRequest::kSendRequest: { TF_RETURN_IF_ERROR( channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - TF_RETURN_IF_ERROR(computation->AddSendInstruction(arg->send_request())); - return tensorflow::Status::OK(); + // Send does not return a value, but we need a handle to be able to + // set OpMetadata and OpSharding (device assignment). + handle_status = computation->AddSendInstruction(arg->send_request()); + break; } case OpRequest::kRecvRequest: { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 6ce241971156599aaa25aea1b0caac0e1bd5379c..e399f1ac1904f8d6145f43b0ed12d8018765d9a1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -112,12 +112,29 @@ class Service : public ServiceInterface { tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; + // Executes a computation with the provided global data passed as + // immutable arguments. The request contains the whole computation graph. + // Returns global data output and execution timing. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; + // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) override; + // Executes one or more computations in parallel with the provided global data + // passed as immutable arguments. Returns global data output for each + // computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + tensorflow::Status ExecuteGraphParallel( + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; + // Requests one or more device handles from the target. // // When N device handles are requested and the number of replicas is R, at @@ -189,6 +206,9 @@ class Service : public ServiceInterface { // Computes the value of a constant expression. tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, ComputeConstantResponse* result) override; + tensorflow::Status ComputeConstantGraph( + const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. @@ -216,6 +236,13 @@ class Service : public ServiceInterface { const ComputationStatsRequest* arg, ComputationStatsResponse* result) override; + // Retrieves the statistics of a computation. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + tensorflow::Status GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; + // Snapshots the current state of a computation handle into a serializable // protocol buffer form, so it can be loaded via // LoadComputationSnapshot. @@ -252,7 +279,21 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); + + // Picks a parallel response and fills the result. + Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, + ExecuteResponse* result); + + // Prepare the executors for executing parallel. + StatusOr> GetExecutors( + const ExecutionOptions& execution_options, int64 requests_size, + int64 request_index) const; + + // Prepare the arguments for executing parallel. + StatusOr>> GetArguments( + const ExecutionOptions& execution_options, + tensorflow::gtl::ArraySlice arguments); protected: friend class LocalExecutable; @@ -262,14 +303,15 @@ class Service : public ServiceInterface { Service(const ServiceOptions& options, std::unique_ptr execute_backend); - static StatusOr> CreateComputeConstantBackend(); - // Resolves the given argument handles in the allocation tracker and returns - // the corresponding allocations. The function also verifies that each - // allocation matches the execution platform and device ordinal. - StatusOr> ResolveAndValidateArguments( + // the corresponding allocations for every replica. The function also verifies + // that each allocation matches the execution platform and device ordinal of + // the corresponding replica. + StatusOr>> + ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - int device_ordinal); + tensorflow::gtl::ArraySlice + stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -277,7 +319,7 @@ class Service : public ServiceInterface { const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options, - const UserComputation& user_computation); + const UserComputation* user_computation = nullptr); // Builds an Executable for the given parameters. // @@ -290,6 +332,15 @@ class Service : public ServiceInterface { perftools::gputools::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); + // Builds an Executable for the given HLO module proto. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr> BuildExecutable( + const HloModuleProto& module_proto, + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator = nullptr); + // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. StatusOr>> BuildExecutables( @@ -298,6 +349,12 @@ class Service : public ServiceInterface { Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); + StatusOr>> BuildExecutables( + const std::vector& module_protos, + std::vector> module_configs, + Backend* backend, + std::vector> executors, + DeviceMemoryAllocator* device_allocator); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and @@ -314,16 +371,17 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - const string& result_tag, ExecutionProfile* profile); + const tensorflow::gtl::ArraySlice> + arguments, + Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice> arguments, + tensorflow::gtl::ArraySlice>> + arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, @@ -336,6 +394,14 @@ class Service : public ServiceInterface { const std::function(UserComputation*)>& adder); + // Executes a single computation which has more than one target device. + // The N devices are expected to all return an empty tuple, but one, which + // will be the result of this computation. + tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result); + tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result); + // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. @@ -372,8 +438,6 @@ class Service : public ServiceInterface { CompilationCache compilation_cache_; // Backend to compile and execute computations on. - // - // TODO(b/28616830): Support multiple backends for execution. std::unique_ptr execute_backend_; TF_DISALLOW_COPY_AND_ASSIGN(Service); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 004889b5f216015ee1e1308702b2bf4cb0deb344..77e12d36024dae56003ad4e59b54f9934dfc2c58 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -169,11 +169,11 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s. Got: %s", + return InvalidArgument("Expected non-tuple argument for %s, but got %s.", op_type.ToString().c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s. Got: %s", + return InvalidArgument("Expected non-opaque argument for %s, but got %s.", op_type.ToString().c_str(), ShapeUtil::HumanString(shape).c_str()); } else { @@ -193,8 +193,10 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, const Shape& accumulator_shape = reducer_shape.result(); if (ShapeUtil::Rank(accumulator_shape) != 0) { - return Unimplemented( - "Reduction function currently must have rank-0 result."); + return InvalidArgument( + "Reduction function must have rank 0 (rank %lld reduction function " + "given).", + ShapeUtil::Rank(accumulator_shape)); } // Check that the accumulator can be passed in as the first argument. @@ -235,8 +237,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, reducer_shape.parameters(1))) { return InvalidArgument( - "Reduction function's second parameter shape currently must " - "match the result shape. Got %s vs %s", + "Reduction function's second parameter shape must " + "match the result shape, but got %s vs %s.", ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(), ShapeUtil::HumanString(accumulator_shape).c_str()); } @@ -258,29 +260,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (int64 i = 0; i < window.dimensions_size(); ++i) { const auto& dim = window.dimensions(i); if (dim.size() <= 0) { - return InvalidArgument("Window has a non-positive dimension. Window: %s", + return InvalidArgument("Window %s has a non-positive dimension.", window.DebugString().c_str()); } if (dim.stride() <= 0) { - return InvalidArgument("Window has a non-positive stride. Window: %s", + return InvalidArgument("Window %s has a non-positive stride.", window.DebugString().c_str()); } if (!allow_negative_padding && dim.padding_low() < 0) { - return InvalidArgument("Window has a negative low padding. Window: %s", + return InvalidArgument("Window %s has a negative low padding.", window.DebugString().c_str()); } if (!allow_negative_padding && dim.padding_high() < 0) { - return InvalidArgument("Window has a negative high padding. Window: %s", + return InvalidArgument("Window %s has a negative high padding.", window.DebugString().c_str()); } if (dim.base_dilation() < 1) { return InvalidArgument( - "Window has a non-positive base area dilation factor. Window: %s", + "Window %s has a non-positive base area dilation factor.", window.DebugString().c_str()); } if (dim.window_dilation() < 1) { return InvalidArgument( - "Window has a non-positive window dilation factor. Window: %s", + "Window %s has a non-positive window dilation factor.", window.DebugString().c_str()); } @@ -302,12 +304,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const HloInstruction* operand) { + return InferUnaryOpShape(opcode, operand->shape()); +} + +/* static */ StatusOr ShapeInference::InferUnaryOpShape( + HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. if (opcode == HloOpcode::kCopy) { - return operand->shape(); + return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape()); + return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); } /* static */ StatusOr ShapeInference::InferUnaryOpShape( @@ -320,8 +327,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_CEIL: if (!ShapeUtil::ElementIsFloating(arg)) { return InvalidArgument( - "expected element type in shape to be floating for floor/ceil " - "operation; got %s", + "Expected element type in shape to be floating for floor/ceil " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -333,8 +340,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(arg) && !ShapeUtil::ElementIsComplex(arg)) { return InvalidArgument( - "expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s", + "Expected element type in shape to be floating or complex for " + "sin/cos/exp/log/tanh operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -342,8 +349,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_IMAG: if (!ShapeUtil::ElementIsComplex(arg)) { return InvalidArgument( - "expected element type in shape to be complex for real/imag " - "operation; got %s", + "Expected element type in shape to be complex for real/imag " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return ShapeUtil::ChangeElementType(arg, F32); @@ -363,8 +370,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (arg.element_type() != PRED && !primitive_util::IsIntegralType(arg.element_type())) { return InvalidArgument( - "expected pred or an integral element type in argument to not " - "operation; got %s", + "Expected pred or an integral element type in argument to Not " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -372,8 +379,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_IS_FINITE: if (!ShapeUtil::ElementIsFloating(arg)) { return InvalidArgument( - "expected element type in shape to be floating point for IsFinite " - "operation; got %s", + "Expected element type in shape to be floating point for IsFinite " + "operation; got %s.", PrimitiveType_Name(arg.element_type()).c_str()); } return ShapeUtil::ChangeElementType(arg, PRED); @@ -389,10 +396,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, tensorflow::gtl::ArraySlice arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { - return InvalidArgument("Concatenate expects at least one argument"); + return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("dimension to concatenate along out of bounds: %lld", + return InvalidArgument("Concatenate dimension out of bounds: %lld.", dimension); } const Shape* arg_shape = nullptr; @@ -408,14 +415,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " - "(%s)", + "(%s).", ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( - "cannot concatenate arrays with different element types: %s vs %s", + "Cannot concatenate arrays with different element types: %s vs %s.", PrimitiveType_Name(arg_shape->element_type()).c_str(), PrimitiveType_Name(shape->element_type()).c_str()); } @@ -428,9 +435,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // concatenating. } return InvalidArgument( - "cannot concatenate arrays that differ in dimensions other than " + "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld", + "the same): %s vs %s in dimension %lld.", ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::HumanString(*shape).c_str(), dimension); } @@ -452,7 +459,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) && !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( - "Unsupported conversion from complex to real type: %s => %s", + "Conversion from complex to real type %s => %s is not implemented.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -461,7 +468,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "cannot convert from or to tuple type; requested conversion: %s => %s", + "Convert does not allow tuples, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -474,24 +481,23 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, auto old_element_type = operand_shape.element_type(); if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { - return Unimplemented( - "Unsupported conversion between real and complex types: %s => %s", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + return InvalidArgument("Conversion from complex to real type %s => %s.", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); } if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "cannot convert from or to tuple type; requested conversion: %s => %s", + "Cannot convert from or to tuple type; requested conversion: %s => %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( - "cannot bitcast types with different bit-widths: %s => %s", + "Cannot bitcast types with different bit-widths: %s => %s.", PrimitiveType_Name(old_element_type).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -504,20 +510,20 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const int mantissa_bits) { if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( - "expected element type in shape to be floating point for " - "ReducePrecision operation; got %s", + "Expected element type in shape to be floating point for " + "ReducePrecision operation; got %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. Having // no exponent bits doesn't produce a sensible number, so we require at // least one. - return InvalidArgument("expected exponent_bits >= 1; got %d", + return InvalidArgument("Expected exponent_bits >= 1; got %d.", exponent_bits); } if (mantissa_bits < 0) { // A number with no mantissa bits is still meaningful, however. - return InvalidArgument("expected non-negative mantissa_bits; got %d", + return InvalidArgument("Expected non-negative mantissa_bits; got %d.", mantissa_bits); } return operand_shape; @@ -528,23 +534,23 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const PaddingConfig& padding_config) { if (ShapeUtil::IsTuple(operand_shape)) { return InvalidArgument( - "pad operation does not support tuple-shape operands"); + "Pad operation does not support tuple-shape operands."); } if (!ShapeUtil::IsScalar(padding_value_shape)) { return InvalidArgument( - "pad operation does not support non-scalar padding values"); + "Pad operation does not support non-scalar padding values."); } if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { return InvalidArgument( "The rank of the operand and the padding configuration do not match: " - "%s vs %s", + "%s vs %s.", ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( - "the element types of the operands to pad do not match"); + "The element types of the operands to Pad do not match."); } std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { @@ -605,7 +611,7 @@ Status ValidateDotDimensionNumbers( lhs_batch_dimensions) || !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { - return InvalidArgument("A dimension number is out of range in dot: %s", + return InvalidArgument("A dimension number is out of range in Dot: %s.", dimension_numbers.DebugString().c_str()); } @@ -623,7 +629,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { - return InvalidArgument("A dimension number is not unique in dot: %s", + return InvalidArgument("A dimension number is not unique in Dot: %s.", dimension_numbers.DebugString().c_str()); } @@ -641,8 +647,7 @@ Status ValidateDotDimensionNumbers( rhs_non_contracting_non_batch_dims < 0 || rhs_non_contracting_non_batch_dims > 1) { return InvalidArgument( - "batch and contracting dimension number mismatch " - "with rank "); + "Batch and contracting dimension number mismatch with rank."); } // Check that batch dimension numbers are ordered before all others, and @@ -654,7 +659,7 @@ Status ValidateDotDimensionNumbers( !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), rhs_batch_dimensions.begin())) { return InvalidArgument( - "batch dimension numbers must precede non-batch dimensions and be" + "Batch dimension numbers must precede non-batch dimensions and be" "monotonically increasing."); } @@ -671,22 +676,22 @@ Status ValidateDotDimensionNumbers( auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( - "cannot infer shape for dot operation: %s %s", + "Cannot infer shape for dot operation: %s %s.", ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); if (!addendum.empty()) { - message += ": " + addendum; + message += " " + addendum; } return InvalidArgument("%s", message.c_str()); }; // Check if both element types are the same. if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return fail("element types do not match"); + return fail("Element types do not match."); } if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { - return fail("dot only supports rank 1 or above."); + return fail("Dot only supports rank 1 or above."); } // Validate basic properties of dot dimension numbers. @@ -696,7 +701,7 @@ Status ValidateDotDimensionNumbers( if (dimension_numbers.lhs_contracting_dimensions_size() != dimension_numbers.rhs_contracting_dimensions_size() || dimension_numbers.lhs_contracting_dimensions_size() != 1) { - return fail("must specify one contracting dimension for both lhs and rhs."); + return fail("Must specify one contracting dimension for both lhs and rhs."); } // Check that contracting dimension sizes match. @@ -706,13 +711,13 @@ Status ValidateDotDimensionNumbers( dimension_numbers.rhs_contracting_dimensions(0); if (lhs.dimensions(lhs_contracting_dimension) != rhs.dimensions(rhs_contracting_dimension)) { - return fail("contracting dimension sizes do not match."); + return fail("Contracting dimension sizes do not match."); } // Check that number of batch dimensions match. if (dimension_numbers.lhs_batch_dimensions_size() != dimension_numbers.rhs_batch_dimensions_size()) { - return fail("must the same number of batch dimensions for lhs and rhs."); + return fail("Must the same number of batch dimensions for lhs and rhs."); } // Check that batch dimension numbers and sizes match. @@ -721,7 +726,7 @@ Status ValidateDotDimensionNumbers( dimension_numbers.rhs_batch_dimensions(i) || lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { - return fail("batch dimension numbers and sizes must match for lhs/rhs."); + return fail("Batch dimension numbers and sizes must match for lhs/rhs."); } } @@ -770,10 +775,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else if (rhs.dimensions(i) == 1) { output_dimensions[i] = lhs.dimensions(i); } else { - return InvalidArgument("binary op %s with incompatible shapes: %s and %s", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + return InvalidArgument( + "Binary op %s with incompatible shapes: %s and %s.", + BinaryOperation_Name(operation).c_str(), + ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str()); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -788,15 +794,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. - return InvalidArgument("automatic shape inference not supported: %s and %s", + return InvalidArgument("Automatic shape inference not supported: %s and %s", ShapeUtil::HumanString(smaller_shape).c_str(), ShapeUtil::HumanString(larger_shape).c_str()); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( - "size of broadcast_dimensions has to match lower-rank operand's " + "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu", + "%zu.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -846,13 +852,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "broadcast dimension number (%lld) cannot be negative", + "Broadcast dimension number (%lld) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "broadcast dimension number (%lld) too large; higher-rank " - "operand has rank %d", + "Broadcast dimension number (%lld) too large; higher-rank " + "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } int64 small_dimension_size = smaller_shape.dimensions(i); @@ -863,7 +869,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i, + "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, small_dimension_size, large_dimension_size, ShapeUtil::HumanString(smaller_shape).c_str(), ShapeUtil::HumanString(larger_shape).c_str()); @@ -872,7 +878,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "broadcast dimensions order is wrong: %lld comes after %lld", + "Broadcast dimensions order is wrong: %lld comes after %lld.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -892,7 +898,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( - "binary op %s with different element types: %s and %s", + "Binary op %s with different element types: %s and %s.", BinaryOperation_Name(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); @@ -904,8 +910,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!broadcast_dimensions.empty() && broadcast_dimensions != identity_dims) { return InvalidArgument( - "broadcast dimensions field must either be not set or be the " - "identity on binary operations with operands of the same rank"); + "Broadcast dimensions field must either be not set or be the " + "identity on binary operations with operands of the same rank."); } } @@ -943,6 +949,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( rhs->shape(), /*broadcast_dimensions=*/{}); } +/* static */ StatusOr ShapeInference::InferBinaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, + broadcast_dimensions); +} + /* static */ StatusOr ShapeInference::InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -979,8 +992,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_COMPLEX: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( - "expected element type in shape to be floating for complex compose " - "operation; got %s", + "Expected element type in shape to be floating for complex compose " + "operation; got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, @@ -989,7 +1002,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { - return Unimplemented("complex component type not supported"); + return Unimplemented("Complex component type is not implemented."); } } case BINOP_AND: @@ -997,8 +1010,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( - "expected pred or integral type in argument to and/or operation; " - "got %s", + "Expected pred or integral type in argument to and/or operation; " + "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } return InferElementwiseBinaryOpShape(operation, lhs, rhs, @@ -1016,7 +1029,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } default: return Unimplemented( - "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s", + "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", BinaryOperation_Name(operation).c_str(), lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); } @@ -1025,8 +1038,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, const HloInstruction* ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(), - rhs->shape(), ehs->shape()); + return InferTernaryOpShape(opcode, lhs->shape(), rhs->shape(), ehs->shape()); +} + +/* static */ StatusOr ShapeInference::InferTernaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { + return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); } /* static */ StatusOr ShapeInference::InferTernaryOpShape( @@ -1041,7 +1058,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case TRIOP_SELECT: return InferSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("unknown operation %s", + return InvalidArgument("Unknown operation %s.", TernaryOperation_Name(operation).c_str()); } } @@ -1053,6 +1070,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } + return InferVariadicOpShape(opcode, operand_shapes); +} + +/* static */ StatusOr ShapeInference::InferVariadicOpShape( + HloOpcode opcode, + tensorflow::gtl::ArraySlice operand_shapes) { return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), operand_shapes); } @@ -1072,7 +1095,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return result; } default: - return InvalidArgument("unknown operation %s", + return InvalidArgument("Unknown operation %s.", VariadicOperation_Name(operation).c_str()); } } @@ -1082,7 +1105,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const ProgramShape& to_apply, tensorflow::gtl::ArraySlice dimensions) { if (arg_shapes.empty()) { - return InvalidArgument("Map expects at least one argument"); + return InvalidArgument("Map expects at least one argument."); } // All arguments must have the same shape. @@ -1113,7 +1136,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } return InvalidArgument( "Map operation requires all operands to have the same shape; got: " - "%s", + "%s.", Join(pieces, ", ").c_str()); } @@ -1122,7 +1145,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu", + "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1130,7 +1153,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int i = 0; i < dimensions.size(); ++i) { if (dimensions[i] != i) { return InvalidArgument( - "Map requires monotonically increasing dimension numbers, found: %s ", + "Map requires monotonically increasing dimension numbers; got: %s.", Join(dimensions, ", ").c_str()); } } @@ -1139,7 +1162,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu", + "arity: %d, arguments: %zu.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1147,8 +1170,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& output_shape = to_apply.result(); if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( - "mapped computation's result has to be a scalar; " - "got: %s", + "Mapped computation's result has to be a scalar; got: %s.", ShapeUtil::HumanString(output_shape).c_str()); } @@ -1157,16 +1179,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::IsScalar(parameter_shape)) { return InvalidArgument( - "mapped computation's parameter has to be a scalar; " - "got parameter %d shape: %s", + "Mapped computation's parameter has to be a scalar; " + "got parameter %d shape: %s.", i, ShapeUtil::HumanString(parameter_shape).c_str()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, *arg_shape)) { return InvalidArgument( - "mapped computation's parameter type has to match argument element " - "type; got parameter %d shape: %s, argument shape: %s", + "Mapped computation's parameter type has to match argument element " + "type; got parameter %d shape: %s, argument shape: %s.", i, ShapeUtil::HumanString(parameter_shape).c_str(), ShapeUtil::HumanString(*arg_shape).c_str()); } @@ -1197,21 +1219,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld", + "be a non-negative number, got %lld.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld", + "batch-norm-training to be at least 1; got %lld.", ShapeUtil::Rank(operand_shape)); } @@ -1232,7 +1254,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-training must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1241,7 +1263,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(offset_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1251,7 +1273,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1264,7 +1286,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of offset factor should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } @@ -1272,7 +1294,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1307,21 +1329,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld", + "be a non-negative number, got %lld.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld", + "batch-norm-inference to be at least 1; got %lld.", ShapeUtil::Rank(operand_shape)); } @@ -1342,7 +1364,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-inference must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1352,7 +1374,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of offset factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(offset_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1363,7 +1385,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of scale factor is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1374,7 +1396,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of mean is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1385,7 +1407,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "The inputs should have the same element type for " "batch-norm-inference, " "but the shape of variance is %s " - "and the shape of operand is %s", + "and the shape of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(variance_shape.element_type()).c_str()); } @@ -1398,7 +1420,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of offset factor should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } @@ -1406,7 +1428,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1414,7 +1436,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of mean should be the same as feature count," "but the size of mean is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } @@ -1422,7 +1444,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of variance should be the same as feature count," "but the size of variance is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1455,7 +1477,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld", + "got feature_index %lld, and rank %lld.", feature_index, ShapeUtil::Rank(operand_shape)); } @@ -1463,7 +1485,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld", + " rank(output_grad_shape) %lld.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } @@ -1491,14 +1513,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::ElementIsFloating(operand_shape)) { return InvalidArgument( "The operand to batch-norm-grad must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(operand_shape.element_type()).c_str()); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " - "element type, but the shape is %s", + "element type, but the shape is %s.", PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } @@ -1507,7 +1529,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(output_grad_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1517,7 +1539,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(scale_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1527,7 +1549,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1537,7 +1559,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " - "and the element type of operand is %s", + "and the element type of operand is %s.", PrimitiveType_Name(mean_shape.element_type()).c_str(), PrimitiveType_Name(operand_shape.element_type()).c_str()); } @@ -1551,7 +1573,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of mean should be the same as feature count," "but the size of offset factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } @@ -1559,7 +1581,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of scale factor should be the same as feature count," "but the size of scale factor is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1567,7 +1589,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The size of variance should be the same as feature count," "but the size of variance is %lld " - "and the feature count is %lld", + "and the feature count is %lld.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1578,7 +1600,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld", + "and the bound of output_grad_shape is %lld.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1596,7 +1618,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( - "Convolution with different element types: %s and %s", + "Convolution with different element types: %s and %s.", ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -1612,21 +1634,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" - "Window: %s\nDimension numbers: %s", + "Window: %s\nDimension numbers: %s.", window.DebugString().c_str(), dnums.DebugString().c_str()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( - "The LHS argument to a convolution should have rank %d.\n" - "lhs: %s", + "The LHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs).c_str()); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( - "The RHS argument to a convolution should have rank %d.\n" - "lhs: %s", + "The RHS argument to a convolution should have rank %d; lhs: %s.", num_dims, ShapeUtil::HumanString(lhs).c_str()); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); @@ -1663,26 +1683,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( - "A dimension number is out of range in convolution: %s", + "A dimension number is out of range in convolution: %s.", dnums.DebugString().c_str()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " - "once: %s", + "once: %s.", dnums.DebugString().c_str()); } @@ -1706,7 +1726,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Expected LHS feature dimension (value %lld) to match RHS " "input feature dimension (value %lld); got (%s, %s)\n" - "Dimension numbers: {%s}", + "Dimension numbers: {%s}.", input_features, kernel_input_features, ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); @@ -1720,7 +1740,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "Window dimensions do not match RHS shape:\n\t" "RHS shape: %s\n\t" "Window: {%s}\n\t" - "Dimension numbers: {%s}", + "Dimension numbers: {%s}.", ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), dnums.ShortDebugString().c_str()); } @@ -1748,8 +1768,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const tensorflow::gtl::ArraySlice fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3, but got %lld", - fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); } #define RET_CHECK_RANK(x) \ if (x.dimensions_size() < fft_rank) { \ @@ -1762,7 +1781,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case FFT: case IFFT: if (in.element_type() != C64) { - return InvalidArgument("%s requires C64 input type, found %s", + return InvalidArgument("%s requires C64 input type, found %s.", FftType_Name(fft_type).c_str(), PrimitiveType_Name(in.element_type()).c_str()); } @@ -1770,7 +1789,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return in; case RFFT: { if (in.element_type() != F32) { - return InvalidArgument("RFFT requires F32 input type, found %s", + return InvalidArgument("RFFT requires F32 input type, found %s.", PrimitiveType_Name(in.element_type()).c_str()); } RET_CHECK_RANK(in); @@ -1779,7 +1798,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld", + "dimension %lld is %lld and should be %lld.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1792,7 +1811,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } case IRFFT: { if (in.element_type() != C64) { - return InvalidArgument("IRFFT requires C64 input type, found %s", + return InvalidArgument("IRFFT requires C64 input type, found %s.", PrimitiveType_Name(in.element_type()).c_str()); } RET_CHECK_RANK(in); @@ -1802,7 +1821,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld", + "fft_length, but dimension %lld is %lld and should be %lld.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1812,7 +1831,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld", + "dimension %d is %lld and should be %lld.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1850,8 +1869,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { return InvalidArgument( - "attempting to reduce out-of-bounds dimension %lld in shape %s", - dimension, ShapeUtil::HumanString(arg).c_str()); + "Reducing out-of-bounds dimension %lld in shape %s.", dimension, + ShapeUtil::HumanString(arg).c_str()); } } TF_RETURN_IF_ERROR( @@ -1891,30 +1910,30 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { return InvalidArgument( - "select function must take 2 parameters, but " + "Select function must take 2 parameters, but " "takes %d parameter(s).", select_shape.parameters_size()); } const Shape& select_result_shape = select_shape.result(); if (!ShapeUtil::Compatible(select_result_shape, ShapeUtil::MakeShape(PRED, {}))) { - return Unimplemented("select function must have rank-0 PRED result."); + return InvalidArgument("Select function must have rank-0 PRED result."); } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(0))) { return InvalidArgument( - "select function's first parameter shape currently must " - "match the operand element shape. Got %s vs %s", + "Select function's first parameter shape currently must " + "match the operand element shape, but got %s vs %s.", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( - "select function's second parameter shape currently must " - "match the operand element shape. Got %s vs %s", + "Select function's second parameter shape currently must " + "match the operand element shape, but got %s vs %s.", ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } @@ -1931,8 +1950,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, window_result_shape)) { return InvalidArgument( - "source shape does not match the shape of window-reduced operand: " - "source(%s), window-reduced operand(%s)", + "Source shape does not match the shape of window-reduced operand: " + "source(%s), window-reduced operand(%s).", ShapeUtil::HumanString(source_shape).c_str(), ShapeUtil::HumanString(window_result_shape).c_str()); } @@ -1946,7 +1965,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " - "{%s}; strides: {%s}", + "{%s}; strides: {%s}.", message.c_str(), ShapeUtil::HumanString(arg).c_str(), Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); @@ -1969,7 +1988,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "slice index count does not match argument rank: %zu vs %lld", + "Slice index count does not match argument rank: %zu vs %lld.", starts.size(), ShapeUtil::Rank(arg)); } @@ -1979,7 +1998,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("negative start index to slice: %lld", + return InvalidArgument("Negative start index to slice: %lld.", start_index); } if (limit_index > arg.dimensions(dimension)) { @@ -1999,7 +2018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("stride (%lld) must be positive", stride); + return InvalidArgument("Stride (%lld) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2023,20 +2042,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %lld must be rank1.", ShapeUtil::Rank(start_indices_shape)); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( - "dynamic slice start indices must be of integral type."); + "Dynamic slice start indices must be of integral type."); } const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s)", + "Dynamic slice start number of dimensions %lld (%s) must match rank " + "%lld of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape).c_str()); @@ -2044,7 +2063,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "dynamic slice index count does not match argument rank: %zu vs %lld", + "Dynamic slice index count does not match argument rank: %zu vs %lld.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2052,12 +2071,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("negative size index to dynamic slice: %lld", + return InvalidArgument("Negative size index to dynamic slice: %lld.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "slice dim size %lld greater than dynamic slice dimension: %lld", + "Slice dim size %lld greater than dynamic slice dimension: %lld.", slice_dim_size, input_dim_size); } VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, @@ -2086,20 +2105,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %lld must be rank1.", ShapeUtil::Rank(start_indices_shape)); } if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( - "dynamic update slice start indices must be of integral type."); + "Dynamic update slice start indices must be of integral type."); } const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s)", + "Dynamic update slice start number of dimensions %lld (%s) must match " + "rank %lld of slice input (%s).", start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape).c_str()); @@ -2107,16 +2126,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "dynamic update slice update rank does not match argument rank: " - "%lld vs %lld", + "Dynamic update slice update rank does not match argument rank: " + "%lld vs %lld.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, update_shape)) { return InvalidArgument( - "dynamic update slice update element type does not match argument. " - "operand.element_type: %s vs update.element_type: %s", + "Dynamic update slice update element type does not match argument. " + "operand.element_type: %s vs update.element_type: %s.", PrimitiveType_Name(operand_shape.element_type()).c_str(), PrimitiveType_Name(update_shape.element_type()).c_str()); } @@ -2126,12 +2145,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "size index %lld to dynamic update slice must be >= 0", + "Size index %lld to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "update dim size %lld greater than dynamic slice dimension: %lld", + "Update dim size %lld greater than dynamic slice dimension: %lld.", update_dim_size, input_dim_size); } VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, @@ -2151,7 +2170,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "one of the reverse dimensions (%lld) is out-of-bounds in shape %s", + "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", dimension, ShapeUtil::HumanString(operand_shape).c_str()); } } @@ -2162,14 +2181,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& arg, int64 index) { if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( - "cannot infer shape: attempting to index into non-tuple: %s", + "Cannot infer shape: attempting to index into non-tuple: %s.", ShapeUtil::HumanString(arg).c_str()); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "cannot infer shape: attempt to index out of tuple bounds: %lld " - ">= %d in shape %s", + "Cannot infer shape: attempt to index out of tuple bounds: %lld " + ">= %d in shape %s.", index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); } @@ -2181,17 +2200,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& init) { // Check the number of parameters for given computations. if (condition.parameters_size() != 1) { - return InvalidArgument("condition must take 1 arguments; got %d", + return InvalidArgument("Condition must take 1 arguments; got %d.", condition.parameters_size()); } if (body.parameters_size() != 1) { - return InvalidArgument("body must take 1 arguments; got %d", + return InvalidArgument("Body must take 1 arguments; got %d.", body.parameters_size()); } auto shape_string = [&]() { return tensorflow::strings::Printf( - "condition: %s; body: %s; init: %s", + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition).c_str(), ShapeUtil::HumanString(body).c_str(), ShapeUtil::HumanString(init).c_str()); @@ -2199,15 +2218,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Check the shapes of computation parameters and return types. if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { - return InvalidArgument("condition must return a boolean; got %s", + return InvalidArgument("Condition must return a boolean; got %s.", shape_string().c_str()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || !ShapeUtil::Compatible(body.result(), init)) { return InvalidArgument( - "the parameter of condition and body, the result of the body, and init " - "must all have the same shape; got %s", + "The parameter of condition and body, the result of the body, and init " + "must all have the same shape; got %s.", shape_string().c_str()); } @@ -2219,7 +2238,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& false_operand, const ProgramShape& true_computation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { - return InvalidArgument("predicate must be a boolean; got %s.", + return InvalidArgument("Predicate must be a boolean; got %s.", ShapeUtil::HumanString(predicate).c_str()); } @@ -2302,8 +2321,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s)", + "Reshape operation has mismatched element counts: from=%lld (%s) " + "to=%lld (%s).", ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::ElementsIn(inferred_shape), ShapeUtil::HumanString(inferred_shape).c_str()); @@ -2351,7 +2370,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { - return InvalidArgument("clamp op with different operand types: %s, %s, %s", + return InvalidArgument("Clamp with different operand types: %s, %s, %s.", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); @@ -2372,7 +2391,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } return Unimplemented( - "not yet implemented: %s, %s %s", min.ShortDebugString().c_str(), + "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); } @@ -2391,25 +2410,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } if (!compatible) { return InvalidArgument( - "operands to select must be the same shape; got %s and %s", + "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), ShapeUtil::HumanString(on_false).c_str()); } if (pred.element_type() != PRED) { return InvalidArgument( - "select's pred operand must have PRED element type; got %s", + "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred).c_str()); } - if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) { + if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + ShapeUtil::Rank(pred) == 0) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. return ShapeUtil::ChangeElementType( on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { - return Unimplemented( - "select operation with non-scalar predicate with dimensionality " - " different from the other operands: %s", + return InvalidArgument( + "Select operation with non-scalar predicate with dimensionality " + " different from the other operands: %s.", ShapeUtil::HumanString(pred).c_str()); } } @@ -2427,7 +2447,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Call applied function arity must match number of arguments; got: " "arity: %d, arguments: %zu; computation signature: %s; argument " - "shapes: [%s]", + "shapes: [%s].", to_apply.parameters_size(), arg_shapes.size(), computation_signature.c_str(), argument_shapes.c_str()); } @@ -2439,7 +2459,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::Compatible(arg_shape, param_shape)) { return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " - "argument shape: %s", + "argument shape: %s.", i, ShapeUtil::HumanString(param_shape).c_str(), ShapeUtil::HumanString(arg_shape).c_str()); } @@ -2448,4 +2468,209 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return to_apply.result(); } +static Status ValidateGatherDimensionNumbers( + const Shape& input_shape, + tensorflow::gtl::ArraySlice gather_indices_shape, + const GatherDimensionNumbers& dim_numbers) { + if (!c_is_sorted(dim_numbers.output_window_dims())) { + return InvalidArgument( + "Output window dimensions in gather op must be ascending; got: %s.", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.output_window_dims()) != + dim_numbers.output_window_dims().end()) { + return InvalidArgument( + "Output window dimensions in gather op must not repeat; got: %s.", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_shape_rank = + output_window_dim_count + gather_indices_shape.size() - 1; + + for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { + int64 window_index = dim_numbers.output_window_dims(i); + if (window_index < 0 || window_index >= output_shape_rank) { + return InvalidArgument( + "Window index %d in gather op is out of bounds; got %lld, but should " + "have been in [0,%lld).", + i, window_index, output_shape_rank); + } + } + + if (dim_numbers.gather_dims_to_operand_dims_size() != + gather_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Gather op has %d elements in gather_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of gather_indices is " + "%lld. These two numbers must be equal.", + dim_numbers.gather_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + gather_indices_shape[dim_numbers.index_vector_dim()]); + } + + for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { + int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); + if (gather_dim_to_input_dim < 0 || + gather_dim_to_input_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + input_shape.dimensions_size(), i, gather_dim_to_input_dim); + } + } + + std::vector sorted_gather_dims_to_operand_dims( + dim_numbers.gather_dims_to_operand_dims().begin(), + dim_numbers.gather_dims_to_operand_dims().end()); + + c_sort(sorted_gather_dims_to_operand_dims); + + if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != + sorted_gather_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + } + + for (int64 elided_dim : dim_numbers.elided_window_dims()) { + if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid elided_window_dims set in gather op; valid range is [0, " + "%d), got: %lld.", + input_shape.dimensions_size(), elided_dim); + } + } + + if (!c_is_sorted(dim_numbers.elided_window_dims())) { + return InvalidArgument( + "elided_window_dims in gather op must be sorted; got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.elided_window_dims()) != + dim_numbers.elided_window_dims().end()) { + return InvalidArgument( + "Repeated dimensions not allowed in elided_window_dims in gather op; " + "got: %s.", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +/*static*/ StatusOr ShapeInference::InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + gather_indices_shape, "gather indices operand of gather op")); + + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + return InvalidArgument( + "Gather indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(gather_indices_shape).c_str()); + } + + // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if + // index_vector_dim is rank(P). The bounds of this expanded shape is + // stored in expanded_gather_indices_shape. + + if (gather_indices_shape.dimensions_size() < + gather_dim_numbers.index_vector_dim() || + gather_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Gather index leaf dimension must be within [0, rank(gather_indices) + " + "1). rank(gather_indices) is %d and gather index leaf dimension is " + "%lld.", + gather_indices_shape.dimensions_size(), + gather_dim_numbers.index_vector_dim()); + } + + std::vector expanded_gather_indices_shape; + expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(expanded_gather_indices_shape)); + if (expanded_gather_indices_shape.size() == + gather_dim_numbers.index_vector_dim()) { + expanded_gather_indices_shape.push_back(1); + } + + TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( + input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + + if (window_bounds.size() != input_shape.dimensions_size()) { + return InvalidArgument( + "Gather op must have one window bound for every input dimension; got: " + "len(window_bounds)=%lu, input_shape.rank=%d.", + window_bounds.size(), input_shape.dimensions_size()); + } + + if (window_bounds.size() != + gather_dim_numbers.output_window_dims_size() + + gather_dim_numbers.elided_window_dims_size()) { + return InvalidArgument( + "All components of the window index in a gather op must either be a " + "output window index or explicitly elided; got len(window_bounds)=%lu, " + "output_window_bounds=%s, elided_window_bounds=%s.", + window_bounds.size(), + Join(gather_dim_numbers.output_window_dims(), ",").c_str(), + Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + } + + for (int i = 0; i < window_bounds.size(); i++) { + int64 window_bound = window_bounds[i]; + int64 corresponding_input_bound = input_shape.dimensions(i); + if (window_bound < 0 || window_bound > corresponding_input_bound) { + return InvalidArgument( + "Window bound at index %d in gather op is out of range, must be " + "within " + "[0, %lld), got %lld.", + i, corresponding_input_bound + 1, window_bound); + } + } + + for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { + if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + return InvalidArgument( + "Gather op can only elide window indices with bound 1, but bound is " + "%lld for index %lld at position %d.", + window_bounds[gather_dim_numbers.elided_window_dims(i)], + gather_dim_numbers.elided_window_dims(i), i); + } + } + + int64 result_rank = gather_dim_numbers.output_window_dims_size() + + (expanded_gather_indices_shape.size() - 1); + int64 window_dims_seen = 0; + int64 gather_dims_seen = 0; + std::vector output_dim_bounds; + output_dim_bounds.reserve(result_rank); + for (int64 i = 0; i < result_rank; i++) { + int64 current_bound; + bool is_window_index = + c_binary_search(gather_dim_numbers.output_window_dims(), i); + if (is_window_index) { + while (c_binary_search(gather_dim_numbers.elided_window_dims(), + window_dims_seen)) { + window_dims_seen++; + } + current_bound = window_bounds[window_dims_seen++]; + } else { + if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { + gather_dims_seen++; + } + current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + } + + output_dim_bounds.push_back(current_bound); + } + + return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index b39151ebbc19f5d0b702a80da5069f58c8dfb07d..9da2c99b4177f08ece8daabaf2922ddd7e947a1b 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -37,12 +37,19 @@ namespace xla { // the expected result type for computations that are built up via the API -- // the shape that results from an operation is inferred. Some methods have // overloads for inferring shape at the HLO level. +// +// TODO(b/73352135): Shape inference does not issue very good error messages, in +// part because HloInstruction::ToString() is not available since shape +// inference runs before the HloInstruction object is created. We need a +// solution for this. class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. static StatusOr InferUnaryOpShape(UnaryOperation operation, const Shape& arg); + static StatusOr InferUnaryOpShape(HloOpcode opcode, + const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, const HloInstruction* operand); @@ -51,6 +58,9 @@ class ShapeInference { static StatusOr InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + static StatusOr InferBinaryOpShape( + HloOpcode opcode, const Shape& lhs, const Shape& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -60,6 +70,9 @@ class ShapeInference { static StatusOr InferTernaryOpShape(TernaryOperation operation, const Shape& lhs, const Shape& rhs, const Shape& ehs); + static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, + const Shape& rhs, + const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, @@ -70,6 +83,9 @@ class ShapeInference { static StatusOr InferVariadicOpShape( VariadicOperation operation, tensorflow::gtl::ArraySlice operand_shapes); + static StatusOr InferVariadicOpShape( + HloOpcode opcode, + tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operands); @@ -248,6 +264,14 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers); + // Helper that infers the shape of the tensor produced by a gather operation + // with the given input shape, gather indices shape and gather dimension + // numbers. + static StatusOr InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 026c021165785bd3945d6a846dae446ad45da9b7..0e61994a786b53a295ef9c9c2287b28fbf754d9b 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -18,15 +18,16 @@ limitations under the License. #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { +using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -134,7 +135,7 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("operands to select must be the same shape")); + HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); @@ -339,7 +340,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("source shape does not match")); + HasSubstr("Source shape does not match")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { @@ -350,7 +351,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must take 2 parameters")); + HasSubstr("Select function must take 2 parameters")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { @@ -361,7 +362,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function must have rank-0 PRED")); + HasSubstr("Select function must have rank-0 PRED")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { @@ -372,7 +373,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's first parameter")); + HasSubstr("Select function's first parameter")); } TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { @@ -383,7 +384,7 @@ TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) { init_value_shape_, scatter_program_shape_); ASSERT_FALSE(inferred_status_fail.ok()); ASSERT_THAT(inferred_status_fail.status().error_message(), - HasSubstr("select function's second parameter")); + HasSubstr("Select function's second parameter")); } TEST_F(ShapeInferenceTest, Convolve) { @@ -905,7 +906,7 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("Dot only supports rank")); } // 3D 2D: error @@ -917,7 +918,7 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch and contracting dimension number mismatch")); + HasSubstr("Batch and contracting dimension number mismatch")); } // vector vector -> scalar @@ -1023,7 +1024,7 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("must specify one contracting dimension for both " + HasSubstr("Must specify one contracting dimension for both " "lhs and rhs")); } @@ -1043,7 +1044,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers and sizes must match")); + HasSubstr("Batch dimension numbers and sizes must match")); } // BatchMatMul with different batch dimension numbers fails. @@ -1062,7 +1063,7 @@ TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("batch dimension numbers must precede non-batch")); + HasSubstr("Batch dimension numbers must precede non-batch")); } // BatchMatMul with out-of-range dimension numbers fails. @@ -1165,42 +1166,42 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("automatic")); + HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), - HasSubstr("size of broadcast_dimensions has to match")); + HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), - ContainsRegex("broadcast dimension number .* too large")); + ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array @@ -1209,13 +1210,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), - HasSubstr("broadcast dimensions order is wrong")); + HasSubstr("dimensions order is wrong")); } // Tests for the while instruction with proper shapes. @@ -1241,7 +1242,7 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), - HasSubstr("condition must take 1 arguments")); + HasSubstr("Condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); @@ -1249,14 +1250,14 @@ TEST_F(ShapeInferenceTest, WhileWithBadShapes) { ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("body must take 1 arguments")); + HasSubstr("Body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("condition must return a boolean")); + HasSubstr("Condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = @@ -1300,13 +1301,13 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: -1")); + HasSubstr("dimension out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("dimension to concatenate along out of bounds: 1")); + HasSubstr("dimension out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( @@ -1314,21 +1315,20 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation.")); + HasSubstr("Expected non-tuple argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); - ASSERT_THAT( - inferred_status_error5.status().error_message(), - HasSubstr("cannot concatenate arrays with different element types")); + ASSERT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), - HasSubstr("cannot concatenate arrays that differ in " + HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); } @@ -1466,7 +1466,7 @@ TEST_F(ShapeInferenceTest, Conditional) { ShapeUtil::MakeProgramShape({vector_64_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), - HasSubstr("predicate must be a boolean")); + HasSubstr("Predicate must be a boolean")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, @@ -1527,5 +1527,458 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } +class GatherShapeInferenceTest : public ShapeInferenceTest { + protected: + const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); + const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); + const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); + const Shape s64_4d_tensor_10_9_8_7_1_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); + const Shape s64_4d_tensor_10_9_8_7_5_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape s64_4d_tensor_5_10_9_7_6_ = + ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6}); + const Shape s64_4d_tensor_10_9_5_7_6_ = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + const Shape f32_5d_tensor_50_49_48_47_46_ = + ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); +}; + +TEST_F(GatherShapeInferenceTest, TensorFlowGather) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26})); + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { + // This is equivalent to a dynamic slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3, 4}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { + // The gather indices "tensor" is a scalar S here that's used to slice out + // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0), + /*window_bounds=*/{1, 30, 29, 28, 27})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + tuple_shape_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for input")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, tuple_shape_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for gather indices")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingWindowIndices) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 8, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must be ascending")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowIndices) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must not repeat")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 99, 100, 101}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 2 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 9}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 4 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingElidedWindowDims) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{4}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("All components of the window index in a gather op must either " + "be a output window index or explicitly elided")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 19}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid elided_window_dims set in gather op; valid " + "range is [0, 5), got: 19")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 3}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in elided_window_dims in gather op")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of " + "gather_indices is 5. These two numbers must be equal.")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " + "[0, 5), got: 4->7")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{2, 1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 1, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("elided_window_dims in gather op must be sorted")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{2}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 1, 300, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window bound at index 3 in gather op is out of range, " + "must be within [0, 48), got 300")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Gather op must have one window bound for every input dimension")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 26, 20}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op can only elide window indices with bound 1, " + "but bound is 29 for index 1 at position 0")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/32), + /*window_bounds=*/{30, 29, 28, 27, 26}); + + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather index leaf dimension must be within [0, " + "rank(gather_indices) + 1)")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index fead9b92362bcd1974f2dff6e030bc47dfc5aa85..532f7fd5bfc1dffa86638a6bc51832beebd74e1d 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -226,7 +226,8 @@ StatusOr UserComputation::AddParameterInstruction( return handle; } -Status UserComputation::AddSendInstruction(const SendRequest& send_request) { +StatusOr UserComputation::AddSendInstruction( + const SendRequest& send_request) { tensorflow::mutex_lock lock(mutex_); // Check if the operand of the instruction is valid. @@ -244,7 +245,7 @@ Status UserComputation::AddSendInstruction(const SendRequest& send_request) { VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() << "), data handle " << handle.handle() << ": " << send_request.ShortDebugString(); - return Status::OK(); + return handle; } StatusOr UserComputation::AddRecvInstruction( @@ -315,6 +316,36 @@ StatusOr UserComputation::AddConstantInstruction( return handle; } +StatusOr UserComputation::AddGatherInstruction( + const GatherRequest& gather_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, + LookUpRequest(gather_request.input())); + TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, + LookUpRequest(gather_request.gather_indices())); + + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferGatherShape( + input_request->output_shape(), gather_indices_request->output_shape(), + gather_request.dimension_numbers(), + AsInt64Slice(gather_request.window_bounds()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_gather_request() = gather_request; + + VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << gather_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddGetTupleElementInstruction( const GetTupleElementRequest& get_tuple_element_request) { tensorflow::mutex_lock lock(mutex_); @@ -1253,8 +1284,8 @@ StatusOr UserComputation::AddCustomCallInstruction( TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); } - if (tensorflow::StringPiece(custom_call_request.call_target_name()) - .starts_with("$")) { + if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), + "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", @@ -1276,6 +1307,28 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddHostComputeInstruction( + const HostComputeRequest& host_compute_request) { + tensorflow::mutex_lock lock(mutex_); + + for (const ComputationDataHandle& handle : host_compute_request.operands()) { + TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); + } + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = host_compute_request.shape(); + *request.mutable_request()->mutable_host_compute_request() = + host_compute_request; + + VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << host_compute_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddDotInstruction( const DotRequest& dot_request) { tensorflow::mutex_lock lock(mutex_); @@ -1713,6 +1766,11 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kHostComputeRequest: { + *is_functional = false; + break; + } + case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { @@ -1991,6 +2049,16 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kGatherRequest: { + PureFunctionalVisitor(session_computation, + request.request().gather_request().input(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + request.request().gather_request().gather_indices(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -2643,6 +2711,15 @@ static void ForEachOperand( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& hc_request = + request.request().host_compute_request(); + for (const ComputationDataHandle& operand : hc_request.operands()) { + apply(operand); + } + break; + } + case OpRequest::kDotRequest: { const DotRequest& dot_request = request.request().dot_request(); apply(dot_request.rhs()); @@ -2684,6 +2761,13 @@ static void ForEachOperand( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + apply(gather_request.input()); + apply(gather_request.gather_indices()); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -3231,20 +3315,23 @@ void ComputationLowerer::Visit( HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && + !ShapeUtil::IsTuple(request.output_shape())) { + if (!ShapeUtil::IsTuple(lhs->shape()) && + !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); } - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { + if (!ShapeUtil::IsTuple(rhs->shape()) && + !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { rhs = ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); } - if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { + if (!ShapeUtil::IsTuple(ehs->shape()) && + !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { ehs = ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); } @@ -3299,6 +3386,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& host_compute_request = + request.request().host_compute_request(); + std::vector operands; + for (const ComputationDataHandle& operand : + host_compute_request.operands()) { + operands.push_back(lookup_instruction(operand)); + } + auto output_shape = host_compute_request.shape(); + auto channel_name = host_compute_request.channel_name(); + auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); + hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( + output_shape, operands, channel_name, cost_estimate_ns)); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -3388,7 +3491,6 @@ void ComputationLowerer::Visit( HloInstruction* operand = lookup_instruction(trace_request.operand()); hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(hlo_instruction); break; } @@ -3401,6 +3503,20 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + HloInstruction* input_operand = + lookup_instruction(gather_request.input()); + HloInstruction* gather_indices_operand = + lookup_instruction(gather_request.gather_indices()); + std::vector window_bounds; + c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); + hlo_instruction = add_instruction(HloInstruction::CreateGather( + request.output_shape(), input_operand, gather_indices_operand, + gather_request.dimension_numbers(), window_bounds)); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 54bb24d6d7fe7aa8cc7c684795e40464e4eb6614..5544c868fe905c1ca7e6cab32738440add2e3b4f 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -149,6 +149,10 @@ class UserComputation { StatusOr AddOutfeedInstruction( const OutfeedRequest& outfeed_request); + // Enqueues a host compute instruction onto this user computation. + StatusOr AddHostComputeInstruction( + const HostComputeRequest& host_compute_request); + // Enqueues a call instruction onto this user computation. StatusOr AddCallInstruction( const CallRequest& call_request, @@ -232,12 +236,17 @@ class UserComputation { const UserComputation& false_computation); // Enqueues a Send instruction onto this user computation. - Status AddSendInstruction(const SendRequest& send_request); + StatusOr AddSendInstruction( + const SendRequest& send_request); // Enqueues a Recv instruction onto this user computation. StatusOr AddRecvInstruction( const RecvRequest& recv_request); + // Enqueues a Gather instruction onto this user computation. + StatusOr AddGatherInstruction( + const GatherRequest& gather_request); + // Returns the user-provided name of this user computation, which is provided // via the XLA computation-building API. const string& name() const { return name_; } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index a5f9b01f011ce04f1114c74391a967c62f015221..3ef0cdff6751258e4489ce350deb0931fdf69ef9 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -106,20 +106,12 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kConstant: + case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: + case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; - - case HloOpcode::kTranspose: - return ShapeUtil::TransposeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape(), instruction.dimensions()); - - case HloOpcode::kReshape: - return ShapeUtil::ReshapeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape()); } } diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 981de9b2200a9ae8938db21299580f510834d2f0..ec05a74e286c89dd8db5ae07580e461938d7c087 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -212,7 +213,7 @@ static optional GetLoopTripCount(HloInstruction* while_op) { // Now that we know the index of the induction variable, we can we can try to // compute how many times the loop executes. Start by computing the induction // variable's initial value. - HloEvaluator evaluator; + HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); StatusOr> indvar_init_result = @@ -605,6 +606,78 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { return false; } +static StatusOr TryPropagateConstant(HloInstruction* while_op) { + auto while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + return false; + } + + auto while_body = while_op->while_body(); + auto while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + + auto while_body_param = while_body->parameter_instruction(0); + const HloInstruction::InstructionVector& root_operands = + while_body_root->operands(); + + // Find the loop invariant tuple elements with scalar constant init value and + // build a map from the tuple element index to the constant value. Limit this + // to scalar constant values because propagating array constants can regress + // performance by forcing us to copy constants. + tensorflow::gtl::FlatMap index_to_constant; + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && instr->operand(0) == while_body_param && + ShapeUtil::IsScalar(instr->shape())) { + auto tuple_element = while_init->operand(i); + if (tuple_element->IsConstant()) { + VLOG(3) << "Found loop invariant tuple element " << i << " " + << tuple_element->ToString(); + index_to_constant[i] = tuple_element; + } + } + } + + if (index_to_constant.empty()) { + return false; + } + + // Replace the use of each constant tuple element in the loop_condition and + // loop_body with the corresponding constant value. + auto propagate_constant = [&](HloComputation* computation) -> StatusOr { + HloInstruction* param = computation->parameter_instruction(0); + bool changed = false; + for (auto instr : param->users()) { + // Since only a while-loop with a tuple result reaches here, we can safely + // assume that `param` is a tuple and the first operand of the + // GetTupleElement instruction is a use of `param`. + if (instr->opcode() == HloOpcode::kGetTupleElement) { + VLOG(3) << "tuple index " << instr->tuple_index() << " " + << instr->ToString(); + auto iter = index_to_constant.find(instr->tuple_index()); + if (iter != index_to_constant.end()) { + const HloInstruction* hlo_constant = (*iter).second; + VLOG(3) << "Replace use of " << instr->ToString() << " with " + << hlo_constant->ToString(); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith( + computation->AddInstruction(hlo_constant->Clone()))); + changed = true; + } + } + } + return changed; + }; + + TF_ASSIGN_OR_RETURN(bool changed_cond, + propagate_constant(while_op->while_condition())); + TF_ASSIGN_OR_RETURN(bool changed_body, propagate_constant(while_body)); + + return changed_cond || changed_body; +} + StatusOr WhileLoopSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(3, "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); @@ -635,7 +708,11 @@ StatusOr WhileLoopSimplifier::Run(HloModule* module) { continue; } - StatusOr result = TryRemoveWhileLoop(while_op); + StatusOr result = TryPropagateConstant(while_op); + TF_RETURN_IF_ERROR(result.status()); + changed |= result.ValueOrDie(); + + result = TryRemoveWhileLoop(while_op); TF_RETURN_IF_ERROR(result.status()); if (result.ValueOrDie()) { changed = true; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index d3d55634c97bbdf3f81321d8089bb808c411340b..3d3e1d60f294c3a2574513c1c2f071805a341ad1 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that makes the following transformations on while loops: // // - A while loop with static trip count of 0 is deleted. -// - A while loops with static trip count of 1 is replaced by its body (sans +// - A while loop with static trip count of 1 is replaced by its body (sans // loop). // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index c5183f8d3aee99696ed4114c3f7e451888222137..619e87caa5b6d0f6ec3c3b1489b0d4f50ef29963 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -26,112 +27,140 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { - public: - // Makes a computation that contains a loop that runs num_iters times. - HloComputation* MakeSimpleLoop(int num_iters, HloModule* module); - - // Makes a computation which has one parameter, of the given shape, and always - // returns PRED[]{true}. This is useful as a dummy loop condition. - HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, - HloModule* module); + protected: + // Makes an HloModule that contains a loop with `num_iters` iteration. + void MakeModuleWithSimpleLoop(int num_iters); + + // Similar to MakeModuleWithSimpleLoop except that the loop bound is passed to + // the loop-condition through an element of a tuple which is the + // loop-condition parameter. + void MakeModuleWithSimpleLoopTupleElementLoopBound(int num_iters); }; -HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters, - HloModule* module) { - HloComputation::Builder builder(TestName()); - - auto loop_iter_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - auto loop_data_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1, 2}))); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto loop_var = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(42 + num_iters))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, - limit)); - condition = module->AddEmbeddedComputation(cond_builder.Build()); +void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { + string hlo_string_template = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto loop_var = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto new_loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto loop_data = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - loop_data_init->shape(), loop_var, 1)); - auto new_loop_data = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_data_init->shape(), HloOpcode::kMultiply, loop_data, - loop_data)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); - body = module->AddEmbeddedComputation(body_builder.Build()); + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant({{LOOP_BOUND}}) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - return module->AddEntryComputation(builder.Build()); + string hlo_string = tensorflow::str_util::StringReplace( + hlo_string_template, "{{LOOP_BOUND}}", + tensorflow::strings::StrCat(42 + num_iters), + /*replace_all=*/true); + ParseAndVerifyModule(hlo_string); } -HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation( - const Shape& param_shape, HloModule* module) { - HloComputation::Builder builder(TestName() + ".always_true"); - builder.AddInstruction( - HloInstruction::CreateParameter(0, param_shape, "param")); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); - return module->AddEmbeddedComputation(builder.Build()); +void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( + int num_iters) { + string hlo_string_template = R"( + HloModule SimpleLoopWithIndirectLoopBound + SimpleLoopWithIndirectLoopBound.body { + loop_var.1 = (s32[], s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + limit = s32[] get-tuple-element(loop_var.1), index=2 + ROOT tuple = (s32[], s32[3]{0}, s32[]) tuple(add, multiply, limit) + } + SimpleLoopWithIndirectLoopBound.condition { + loop_var.2 = (s32[], s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=2 + ROOT less-than = pred[] less-than(get-tuple-element.3, get-tuple-element.4) + } + ENTRY SimpleLoopWithIndirectLoopBound { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + constant.2 = s32[] constant({{LOOP_BOUND}}) + tuple.1 = (s32[], s32[3]{0}, s32[]) tuple(constant.3, constant.4, + constant.2) + ROOT while = (s32[], s32[3]{0}, s32[]) while(tuple.1), + condition=SimpleLoopWithIndirectLoopBound.condition, + body=SimpleLoopWithIndirectLoopBound.body + } + )"; + + string hlo_string = tensorflow::str_util::StringReplace( + hlo_string_template, "{{LOOP_BOUND}}", + tensorflow::strings::StrCat(42 + num_iters), + /*replace_all=*/true); + ParseAndVerifyModule(hlo_string); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module()); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), +TEST_F(WhileLoopSimplifierTest, LoopWithZeroIterationSimiplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/0); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), +TEST_F(WhileLoopSimplifierTest, + LoopWithZeroIterationTupleElementLoopBoundSimplified) { + MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/0); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), + op::Tuple(op::Constant(), op::Constant(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithOneIterationSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), op::Tuple(op::Add(), op::Multiply())); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) { - MakeSimpleLoop(/*num_iters=*/2, &module()); +TEST_F(WhileLoopSimplifierTest, + LoopWithOneIterationTupleELementLoopBoundSimplified) { + MakeModuleWithSimpleLoopTupleElementLoopBound(/*num_iters=*/1); + HloModule* the_module = &module(); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + EXPECT_THAT(the_module->entry_computation()->root_instruction(), + op::Tuple(op::Add(), op::Multiply(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithTwoIterationsNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/2); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, + LoopWithControlDependencySimplifiedDependencyPreserved) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* true_op = while_op->while_body()->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(true))); TF_ASSERT_OK(true_op->AddControlDependencyTo( while_op->while_body()->root_instruction())); - ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); EXPECT_THAT(computation->root_instruction()->control_predecessors(), ElementsAre(op::Constant())) << computation->ToString(); @@ -139,8 +168,10 @@ TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { // Loops that contain send/recv nodes can't be simplified; the loop structure // around send/recv nodes must be preserved. -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -149,11 +180,13 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { HloInstruction::CreateConstant(Literal::CreateR0(true))), /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateSendDone(send)); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); @@ -161,247 +194,217 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0)); while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } // The limitation on not being able to simplify loops that contain infeeds (and // other non-removable instructions) isn't fundamental -- it just stems from the // fact that our infrastructure sees simplifying such a loop as tantamount to // removing the non-removable instruction. -TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); +TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) { + MakeModuleWithSimpleLoop(/*num_iters=*/1); + HloModule* the_module = &module(); + HloComputation* computation = the_module->entry_computation(); auto* while_op = computation->root_instruction(); ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); auto* while_body = while_op->while_body(); while_body->AddInstruction( HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); - EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); } -// Check that we don't crash when given a loop whose shape is not a tuple. -TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(100))))); - condition = module().AddEmbeddedComputation(cond_builder.Build()); +// A non-tuple shaped loop shouldn't be simplified or crash the compiler. +TEST_F(WhileLoopSimplifierTest, NonTupleShapedLoopNotSimplified) { + const string hlo_string = R"( + HloModule NonTupleShapedLoop + NonTupleShapedLoop.body { + loop_var.1 = s32[] parameter(0) + constant.1 = s32[] constant(-1) + ROOT add = s32[] add(s32[] loop_var.1, s32[] constant.1) + } + NonTupleShapedLoop.condition { + loop_var = s32[] parameter(0) + constant = s32[] constant(100) + ROOT less-than = pred[] less-than(s32[] loop_var, s32[] constant) + } + ENTRY INonTupleShapedLoop { + constant.2 = s32[] constant(42) + ROOT while = s32[] while(s32[] constant.2), + condition=NonTupleShapedLoop.condition, + body=NonTupleShapedLoop.body } + )"; - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))))); - body = module().AddEmbeddedComputation(body_builder.Build()); - } - - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } -// Construct a loop where we swap the tuple elements in each iteration. -// Although the tuple elements aren't used in the loop, we don't eliminate them, -// because the swapping side-effect is visible to users of the loop. -TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - body_builder.AddInstruction(HloInstruction::CreateTuple({ - body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)), - body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)), - })); - body = module().AddEmbeddedComputation(body_builder.Build()); +// A while loop that does nothing else besides swapping tuple elements +// can't be simplified as the result of the swapping is visible to users of the +// loop. +TEST_F(WhileLoopSimplifierTest, LoopSwappingTupleElementsNotSimplified) { + const string hlo_string = R"( + HloModule SwappingTupleElements + SwappingTupleElements.body { + loop_var = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) loop_var),index=1 + get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[]) loop_var), + index=0 + ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element, + s32[] get-tuple-element.1) } + SwappingTupleElements.always_true { + param = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY SwappingTupleElements { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y) + ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1), + condition=SwappingTupleElements.always_true, + body=SwappingTupleElements.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // Construct a loop where we assign a constant to tuple element 0 in each // iteration. We can't eliminate tuple element 0, even though we never use its // value. -TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0)))})); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateTuple({ - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, + LoopWithUnusedButModifiedTupleElementNotSimplified) { + const string hlo_string = R"( + HloModule UnusedButModifiedTupleElement + UnusedButModifiedTupleElement.body { + loop_var = (s32[]) parameter(0) + constant.1 = s32[] constant(1) + ROOT tuple = (s32[]) tuple(s32[] constant.1) } + UnusedButModifiedTupleElement.always_true { + param = (s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY UnusedButModifiedTupleElement { + constant.2 = s32[] constant(0) + tuple.1 = (s32[]) tuple(s32[] constant.2) + ROOT while = (s32[]) while((s32[]) tuple.1), + condition=UnusedButModifiedTupleElement.always_true, + body=UnusedButModifiedTupleElement.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // Nothing to simplify in a while loop whose tuple has 0 elements. -TEST_F(WhileLoopSimplifierTest, EmptyTuple) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({})); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - body_builder.AddInstruction(HloInstruction::CreateTuple({})); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, LoopWithEmptyTupleNotSimplified) { + const string hlo_string = R"( + HloModule EmptyTuple + EmptyTuple.body { + loop_var = () parameter(0) + ROOT tuple = () tuple() + } + EmptyTuple.always_true { + param = () parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY EmptyTuple { + tuple.1 = () tuple() + ROOT while = () while(() tuple.1), condition=EmptyTuple.always_true, + body=EmptyTuple.body } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // While loop where one tuple element is used twice in the body, and thus can't // be simplified away. -TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))), - })); - - HloComputation* condition = - MakeAlwaysTrueComputation(loop_init->shape(), &module()); - - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto* param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "param0")); - auto* gte0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); - // get0 is used twice in the loop body's tuple. - body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0})); - body = module().AddEmbeddedComputation(body_builder.Build()); +TEST_F(WhileLoopSimplifierTest, LoopWithElemUsedTwiceNotSimplified) { + const string hlo_string = R"( + HloModule ElemUsedTwice + ElemUsedTwice.body { + param0 = (s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[]) param0), index=0 + ROOT tuple = (s32[], s32[]) tuple(s32[] get-tuple-element, + s32[] get-tuple-element) + } + ElemUsedTwice.always_true { + param = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) } + ENTRY ElemUsedTwice { + x = s32[] parameter(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[]) tuple(s32[] x, s32[] y) + ROOT while = (s32[], s32[]) while((s32[], s32[]) tuple.1), + condition=ElemUsedTwice.always_true, body=ElemUsedTwice.body + } + )"; - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - module().AddEntryComputation(builder.Build()); + ParseAndVerifyModule(hlo_string); EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } // This while loop has three tuple elements. Element 0 is unused and should be // removed. Element 1 is used by the loop body, and element 2 is used by the // loop condition; these two should stay. -TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { - HloComputation::Builder builder(TestName()); - auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - })); - auto loop_shape = loop_init->shape(); - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".loop_condition"); - auto param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_shape, "param0")); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0))), - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - scalar_s32, param, /*index=*/2)))); - condition = module().AddEmbeddedComputation(cond_builder.Build()); +TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) { + const string hlo_string = R"( + HloModule RemoveUnusedOperands + RemoveUnusedOperands.body { + loop_var = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element((s32[], s32[], + s32[]) loop_var), index=0 + get-tuple-element.2 = s32[] get-tuple-element((s32[], s32[], + s32[]) loop_var), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(s32[] get-tuple-element.2, s32[] constant.1) + get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[]) + loop_var), index=2 + ROOT tuple = (s32[], s32[], s32[]) tuple(s32[] get-tuple-element.1, + s32[] add, s32[] get-tuple-element.3) } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto* param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_shape, "loop_var")); - - auto* tuple0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); - auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary( - scalar_s32, HloOpcode::kAdd, - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - scalar_s32, param, /*index=*/1)), - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto* tuple2 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({tuple0, tuple1, tuple2})); - - body = module().AddEmbeddedComputation(body_builder.Build()); + RemoveUnusedOperands.loop_condition { + constant.2 = s32[] constant(0) + param0 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element = s32[] get-tuple-element((s32[], s32[], s32[]) param0), + index=2 + ROOT equal-to = pred[] equal-to(s32[] constant.2, s32[] get-tuple-element) } + ENTRY RemoveUnusedOperands { + x = s32[] parameter(0) + constant.3 = s32[] constant(0) + y = s32[] parameter(1) + tuple.1 = (s32[], s32[], s32[]) tuple(s32[] x, s32[] constant.3, + s32[] y) + ROOT while = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) tuple.1), + condition=RemoveUnusedOperands.loop_condition, + body=RemoveUnusedOperands.body + } + )"; + + ParseAndVerifyModule(hlo_string); + HloModule* the_module = &module(); + EXPECT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie()); + + // The original while instruction is still left in the module as a dead + // instruction, find a while instruction with a different name as the new + // while instruction. + HloInstruction* new_while_op = + *std::find_if(the_module->entry_computation()->instructions().begin(), + the_module->entry_computation()->instructions().end(), + [&](const HloInstruction* instr) { + return (instr->opcode() == HloOpcode::kWhile && + instr->name() != "while"); + }); - auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - module().AddEntryComputation(builder.Build()); - EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); - - // We leave most of the checking to HloVerifiedTestBase, which runs the - // verifier on module() at the end of this test. - HloInstruction* new_while_op = *std::find_if( - module().entry_computation()->instructions().begin(), - module().entry_computation()->instructions().end(), - [&](const HloInstruction* instr) { - return instr != while_op && instr->opcode() == HloOpcode::kWhile; - }); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); EXPECT_TRUE( ShapeUtil::Equal(new_while_op->shape(), ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}))) @@ -418,31 +421,91 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } -TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { - auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); - Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); - - HloComputation* while_body = [&]() { - HloComputation::Builder builder(TestName() + ".passthrough"); - HloInstruction* param = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "param")); - HloComputation* result = module().AddEmbeddedComputation(builder.Build()); - - result->AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); - return result; - }(); - - HloComputation::Builder builder(TestName()); - auto* init_value = builder.AddInstruction( - HloInstruction::CreateParameter(0, while_shape, "init_value")); - builder.AddInstruction(HloInstruction::CreateWhile( - while_shape, MakeAlwaysTrueComputation(while_shape, &module()), - while_body, init_value)); - module().AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, - WhileLoopSimplifier{}.Run(&module())); - EXPECT_FALSE(simplified_loop); +TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) { + const string hlo_string = R"( + HloModule BodyHasNonTupleRoot + BodyHasNonTupleRoot.passthrough { + ROOT param = (s32[], s32[]) parameter(0) + } + BodyHasNonTupleRoot.always_true { + param.1 = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) + } + ENTRY BodyHasNonTupleRoot { + init_value = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while((s32[], s32[]) init_value), + condition=BodyHasNonTupleRoot.always_true, + body=BodyHasNonTupleRoot.passthrough + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, + LoopWithNonTupleBodyRootInstructionNotSimplified) { + const string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT custom-call = (s32[], s32[3]{0}) custom-call(add, multiply), + custom_call_target="x" + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(44) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) { + const string hlo_string = R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s32[3]{0} get-tuple-element(loop_var.1), index=2 + add.2 = s32[3]{0} add(get-tuple-element.2, get-tuple-element.3) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, add.2, get-tuple-element.3) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}, s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(47) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(42) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4, constant.4) + ROOT while = (s32[], s32[3]{0}, s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + } + )"; + + ParseAndVerifyModule(hlo_string); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index e20b25e4a08a946f6b58575a4d4e557744f8035c..bd0794184328b7926543c4275b3b915f51e7b812 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,18 +15,21 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { +using tensorflow::strings::StrCat; + static StatusOr WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { const Shape& narrow_shape = narrow_condition->parameter_instruction(0)->shape(); HloComputation* wide_while_cond = [&]() { - HloComputation::Builder builder( - tensorflow::strings::StrCat("wide.", narrow_condition->name())); + HloComputation::Builder builder(StrCat("wide.", narrow_condition->name())); builder.AddInstruction( HloInstruction::CreateParameter(0, wide_shape, "wide_param")); @@ -57,8 +60,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); HloComputation* wide_while_body = [&]() { - HloComputation::Builder builder( - tensorflow::strings::StrCat("wide.", narrow_body->name())); + HloComputation::Builder builder(StrCat("wide.", narrow_body->name())); builder.AddInstruction( HloInstruction::CreateParameter(0, wide_shape, "wide_param")); return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); @@ -137,4 +139,109 @@ WhileUtil::MakeInstructionsLiveIn( return std::move(result); } + +static StatusOr> +MakeCountedLoopConditionComputation(const Shape& loop_state_shape, + int32 trip_count) { + Shape scalar_pred = ShapeUtil::MakeShape(PRED, {}); + + TF_ASSIGN_OR_RETURN(std::unique_ptr cond_computation, + CreateComputationWithSignature( + {&loop_state_shape}, scalar_pred, "while_cond")); + + HloInstruction* trip_count_constant = cond_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(trip_count))); + + HloInstruction* param = cond_computation->parameter_instruction(0); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + + TF_ASSIGN_OR_RETURN( + HloInstruction * compare, + MakeBinaryHlo(HloOpcode::kLt, indvar, trip_count_constant)); + cond_computation->set_root_instruction(compare); + return std::move(cond_computation); +} + +static StatusOr> MakeCountedLoopBodyComputation( + const Shape& loop_state_shape, + const std::function( + HloInstruction*, const WhileUtil::LoopStateTy&)>& loop_body_generator) { + TF_ASSIGN_OR_RETURN(std::unique_ptr body_computation, + CreateComputationWithSignature( + {&loop_state_shape}, loop_state_shape, "while_body")); + HloInstruction* one = body_computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + HloInstruction* param = body_computation->parameter_instruction(0); + TF_ASSIGN_OR_RETURN(HloInstruction * indvar, + MakeGetTupleElementHlo(param, 0)); + TF_ASSIGN_OR_RETURN(HloInstruction * next_indvar, + MakeBinaryHlo(HloOpcode::kAdd, indvar, one)); + + std::vector loop_body_generator_args; + for (int64 i = 1, e = loop_state_shape.tuple_shapes_size(); i < e; i++) { + TF_ASSIGN_OR_RETURN(HloInstruction * tuple_element, + MakeGetTupleElementHlo(param, i)); + loop_body_generator_args.push_back(tuple_element); + } + TF_ASSIGN_OR_RETURN(std::vector next_state, + loop_body_generator(indvar, loop_body_generator_args)); + next_state.insert(next_state.begin(), next_indvar); + HloInstruction* next_state_tuple = + body_computation->AddInstruction(HloInstruction::CreateTuple(next_state)); + body_computation->set_root_instruction(next_state_tuple); + + return std::move(body_computation); +} + +static StatusOr MakeInitTupleFromInitValues( + HloComputation* computation, const WhileUtil::LoopStateTy& init_values) { + std::vector init_values_with_indvar; + init_values_with_indvar.reserve(init_values.size() + 1); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + init_values_with_indvar.push_back(zero); + c_copy(init_values, std::back_inserter(init_values_with_indvar)); + return computation->AddInstruction( + HloInstruction::CreateTuple(init_values_with_indvar)); +} + +static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { + std::vector loop_state_shape_components; + loop_state_shape_components.reserve(init_values.size() + 1); + loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); + c_transform(init_values, std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); + return ShapeUtil::MakeTupleShape(loop_state_shape_components); +} + +/*static*/ StatusOr WhileUtil::MakeCountedLoop( + HloComputation* computation, int32 trip_count, + const WhileUtil::LoopStateTy& init_values, + const WhileUtil::LoopBodyGeneratorTy& loop_body_generator) { + CHECK_GE(trip_count, 0); + + Shape loop_state_shape = MakeLoopStateShape(init_values); + TF_ASSIGN_OR_RETURN( + std::unique_ptr cond, + MakeCountedLoopConditionComputation(loop_state_shape, trip_count)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr body, + MakeCountedLoopBodyComputation(loop_state_shape, loop_body_generator)); + TF_ASSIGN_OR_RETURN(HloInstruction * init_tuple, + MakeInitTupleFromInitValues(computation, init_values)); + HloModule* module = computation->parent(); + HloInstruction* while_instr = + computation->AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, module->AddEmbeddedComputation(std::move(cond)), + module->AddEmbeddedComputation(std::move(body)), init_tuple)); + + std::vector result; + for (int64 i = 0, e = init_values.size(); i < e; i++) { + TF_ASSIGN_OR_RETURN(HloInstruction * user_state, + MakeGetTupleElementHlo(while_instr, i + 1)); + result.push_back(user_state); + } + return result; +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 3600b5a80d26e37fdb7d5173c3b8743734306390..1688d4674269c36c5b356f262dbd5d958572e101 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -52,6 +52,28 @@ class WhileUtil { static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); + + using LoopStateTy = std::vector; + using LoopBodyGeneratorTy = std::function( + HloInstruction* /*induction_var*/, + const LoopStateTy& /*current_values*/)>; + + // Creates a while loop in `computation` that runs for `trip_count` + // iterations. The structure of the while loop is as follows, in pseudocode: + // + // loop_state while_loop() { + // indvar = 0; + // loop_state = init_values + // while (indvar < trip_count) { + // loop_state = loop_body_generator(loop_state) + // indvar++; + // } + // return loop_state; + // } + static StatusOr MakeCountedLoop( + HloComputation* computation, int32 trip_count, + const LoopStateTy& init_values, + const LoopBodyGeneratorTy& loop_body_generator); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 063e312df66ce9cba0fa9f49c2fc6026ba6b74aa..8763e588c484011ba2ccbc7cad8f29817347a605 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -// HLO pass that replaces zero sized Hlos with an zero sized constant literal. +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 809941d8fe1f63d66bf104e66eea66167a0f509d..5b44c26b7c7b082556d9533cf3b3b1b98e5e4b09 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -54,9 +54,16 @@ class ServiceInterface { virtual tensorflow::Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) = 0; + virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; + virtual tensorflow::Status ExecuteParallel( const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual tensorflow::Status ExecuteGraphParallel( + const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; + virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) = 0; @@ -69,6 +76,10 @@ class ServiceInterface { virtual tensorflow::Status GetComputationStats( const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + virtual tensorflow::Status GetComputationGraphStats( + const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) = 0; + virtual tensorflow::Status GetComputationShape( const GetComputationShapeRequest* arg, GetComputationShapeResponse* result) = 0; @@ -101,6 +112,10 @@ class ServiceInterface { virtual tensorflow::Status ComputeConstant( const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + virtual tensorflow::Status ComputeConstantGraph( + const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; + // Methods used by Computation. virtual tensorflow::Status SnapshotComputation( const SnapshotComputationRequest* ag, diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 280f02e88675381bd75108bfae0dd22c462ba718..ffaa40c2d673a2365342371ed8dab59565d1d08f 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -53,7 +53,7 @@ struct ShapeTreeNode { ShapeTreeNode(const ShapeTreeNode& other) : data(other.data), children(other.children.size()) { for (size_t i = 0; i < children.size(); ++i) { - children[i] = MakeUnique(*other.children[i]); + children[i] = ::xla::MakeUnique(*other.children[i]); } } @@ -62,7 +62,7 @@ struct ShapeTreeNode { data = other.data; children.resize(other.children.size()); for (size_t i = 0; i < children.size(); ++i) { - children[i] = MakeUnique(*other.children[i]); + children[i] = ::xla::MakeUnique(*other.children[i]); } } return *this; @@ -445,7 +445,7 @@ class ShapeTreeIterator : public std::iterator(index, node_->data); + current_ = ::xla::MakeUnique(index, node_->data); return *current_; } @@ -492,7 +492,7 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { template ShapeTree::ShapeTree(Shape shape) : root_(), - shape_storage_(MakeUnique(std::move(shape))), + shape_storage_(::xla::MakeUnique(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. @@ -508,7 +508,7 @@ ShapeTree::ShapeTree(const Shape* shape) : root_(), shape_(shape) { template ShapeTree::ShapeTree(Shape shape, const T& init_value) : root_(init_value), - shape_storage_(MakeUnique(std::move(shape))), + shape_storage_(::xla::MakeUnique(std::move(shape))), shape_(shape_storage_.get()) { // The shape_ field is just used to hold the structure of the shape. // It should not be relied upon to store layout information. diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 604e0173e789348923316174873f58058eaf2815..ac7e201bfdceabdd0f11db61bbb3b460017401ca 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -502,11 +502,11 @@ namespace { StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { tensorflow::str_util::RemoveLeadingWhitespace(s); - if (s->Consume("(")) { // Tuple. + if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; bool must_end = false; while (true) { - if (s->Consume(")")) { + if (tensorflow::str_util::ConsumePrefix(s, ")")) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", @@ -515,7 +515,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !s->Consume(","); + must_end = !tensorflow::str_util::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } @@ -609,6 +609,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, const Shape& rhs) { + CHECK(ShapeUtil::IsArray(lhs)); + CHECK(ShapeUtil::IsArray(rhs)); return ContainersEqual(lhs.dimensions(), rhs.dimensions()); } @@ -617,7 +619,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); } - return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); + if (lhs.element_type() == OPAQUE) { + return rhs.element_type() == OPAQUE; + } + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -627,7 +632,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); } - return SameDimensions(lhs, rhs); + if (lhs.element_type() == OPAQUE) { + return rhs.element_type() == OPAQUE; + } + return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, @@ -637,6 +645,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); } + if (lhs.element_type() == OPAQUE) { + return rhs.element_type() == OPAQUE; + } if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return CompatibleIgnoringElementType(lhs, rhs); } @@ -813,6 +824,18 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return new_shape; } +/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape, + ShapeIndexView index) { + const Shape* subshape = &shape; + for (auto i : index) { + if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size()) { + return false; + } + subshape = &subshape->tuple_shapes(i); + } + return true; +} + /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; @@ -1073,9 +1096,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping) { - // Can't insert bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) && - !LayoutUtil::HasLayout(output_shape)) { + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); + + if (!SameElementType(input_shape, output_shape)) { return false; } @@ -1106,9 +1130,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - // Can't convert reshapes into bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) || - !LayoutUtil::HasLayout(output_shape)) { + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); + + if (!SameElementType(input_shape, output_shape)) { return false; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 19b1aa93bd373ebd5f502d0dca56c9b31ab4fd7f..63da9154cfc1a5e7e8c0eeaa103d27096540fefe 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -24,11 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -208,6 +211,7 @@ class ShapeUtil { // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. + // Precondition: IsArray(lhs) && IsArray(rhs) static bool SameDimensions(const Shape& lhs, const Shape& rhs); // Returns whether the lhs and rhs shapes have the same element type. @@ -315,11 +319,25 @@ class ShapeUtil { // Returns an empty tuple shape. Can be used to indicate side-effects. static Shape MakeNil() { return MakeTupleShape({}); } + // Checks whether the shape is initialized. + static bool IsInitialized(const Shape& shape) { + return shape.element_type() != PRIMITIVE_TYPE_INVALID; + } + // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions); + // Creates a Shape with element type corresponding to T and the given + // dimensions + template + static Shape MakeShapeWithType( + tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + dimensions); + } + // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). static Shape MakeShapeWithLayout( @@ -430,6 +448,9 @@ class ShapeUtil { static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); + // Returns true if the given shape has a subshape at the given index. + static bool IndexIsValid(const Shape& shape, ShapeIndexView index); + // GetSubshape and GetMutableSubshape return a particular nested Shape within // the given Shape argument. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); @@ -522,12 +543,16 @@ class ShapeUtil { // Returns whether a transpose from input_shape to output_shape with dimension // mapping "dimension_mapping" produces a result which is bit-wise identical // to its input and thus may be replaced with a bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); @@ -560,16 +585,84 @@ class ShapeUtil { // The visitor_function visitor function should return true if it wants to // continue, or false otherwise. // - // visitor_function must be a callable of type bool(const std::vector&) - // or compatible. + // visitor_function must be a callable of type + // StatusOr(ArraySlice) or compatible. + template + static Status ForEachIndexWithStatus(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { + return ForEachIndexInternal(shape, base, count, incr, visitor_function); + } + + // Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus. + struct IndexIterationSpace { + std::vector index_base; + std::vector index_count; + std::vector index_incr; + }; + + template + static Status ForEachIndexWithStatus( + const Shape& shape, const IndexIterationSpace& iteration_space, + FnTy&& function) { + return ShapeUtil::ForEachIndexWithStatus( + shape, iteration_space.index_base, iteration_space.index_count, + iteration_space.index_incr, std::forward(function)); + } + template static void ForEachIndex(const Shape& shape, tensorflow::gtl::ArraySlice base, tensorflow::gtl::ArraySlice count, tensorflow::gtl::ArraySlice incr, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, base, count, incr, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); + } + + // A parallel version of ForEachIndex(WithStatus). This can only be used if + // the visitor_function is thread-safe and the order of iteration does not + // matter. + // + // visitor_function must be a callable of type + // void(ArraySlice) or compatible. + template + static void ForEachIndexParallel(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function) { + // The parallel version of ForEachIndexInternal can never fail. + CHECK(ForEachIndexInternal( + shape, base, count, incr, + [&visitor_function](tensorflow::gtl::ArraySlice indexes) + -> StatusOr { + visitor_function(indexes); + return true; + }, + /*parallel=*/true) + .ok()); + } + + private: + // Validates all of the non-layout properties of the shape -- this is a helper + // used by both the layout-optional and layout-required public method. + static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); + + template + static Status ForEachIndexInternal(const Shape& shape, + tensorflow::gtl::ArraySlice base, + tensorflow::gtl::ArraySlice count, + tensorflow::gtl::ArraySlice incr, + const FnType& visitor_function, + bool parallel = false) { if (ShapeUtil::HasZeroElements(shape)) { - return; + return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); @@ -579,7 +672,22 @@ class ShapeUtil { // once with the proper empty indexes. int64 n = -1; std::vector indexes(base.begin(), base.end()); - while (n < rank && visitor_function(indexes)) { + const int kNumThreads = tensorflow::port::NumSchedulableCPUs(); + tensorflow::gtl::optional pool; + if (parallel) { + pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); + } + + while (n < rank) { + if (pool != tensorflow::gtl::nullopt) { + pool->Schedule( + [indexes, &visitor_function] { visitor_function(indexes); }); + } else { + TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes)); + if (!should_continue) { + break; + } + } // Increments dimensions in minor to major order. for (n = 0; n < rank; ++n) { int64 dim = LayoutUtil::Minor(shape.layout(), n); @@ -590,12 +698,9 @@ class ShapeUtil { indexes[dim] = base[dim]; } } - } - private: - // Validates all of the non-layout properties of the shape -- this is a helper - // used by both the layout-optional and layout-required public method. - static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape); + return Status::OK(); + } TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4db97d45b20b86dc60531845c6e28a223203ff7f..13582a2a2678548dfc8e9c329dfb6def9d51fc9d 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -238,6 +238,18 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, IncompatibleScalarVsTuple) { + Shape shape1 = ShapeUtil::MakeShape(F32, {}); + Shape shape2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(U32, {})}); + EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1)); +} + TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -573,10 +585,11 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = [&invocations](const std::vector& indexes) { - invocations++; - return true; - }; + auto increment_func = + [&invocations](tensorflow::gtl::ArraySlice indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -588,6 +601,47 @@ TEST(ShapeUtilTest, ForEachIndex) { } } +TEST(ShapeUtilTest, ForEachIndexWithStatus) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); + // Increments at every invocation. + int invocations = 0; + auto increment_func = + [&invocations]( + tensorflow::gtl::ArraySlice indexes) -> StatusOr { + if (++invocations == 5) { + return Unimplemented("Cannot increment beyond 5."); + } + return true; + }; + + Status error_status = ShapeUtil::ForEachIndexWithStatus( + shape, /*base=*/{0, 0}, /*count=*/{10, 10}, /*incr=*/{0, 1}, + increment_func); + + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT(error_status.error_message(), + ::testing::HasSubstr("Cannot increment beyond 5.")); + EXPECT_EQ(invocations, 5); +} + +TEST(ShapeUtilTest, ForEachIndexParallel) { + Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); + int64 output[10][10]; + int init = 5; + auto set_func = [&](tensorflow::gtl::ArraySlice indexes) { + output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; + }; + + ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 10}, + /*incr=*/{1, 1}, set_func); + + for (int i = 0; i < 10; ++i) { + for (int j = 0; j < 10; ++j) { + EXPECT_EQ(output[i][j], init + i + j); + } + } +} + TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) { // All output dimensions should be unmodified. One of the input dimensions is // modified because the input rank is larger by one. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 8339d08ef4d7455f9739b80074ab0405a404e8e8..1f90a44d8ba725c1bc7d23b581161f8915ff74fd 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:test", ], + alwayslink = True, ) cc_library( @@ -138,6 +139,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -188,6 +190,9 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -271,6 +276,9 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -290,6 +298,9 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -309,6 +320,9 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -325,6 +339,9 @@ xla_test( xla_test( name = "while_test", srcs = ["while_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -332,10 +349,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -366,9 +383,13 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -394,6 +415,8 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -430,6 +453,9 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:computation_builder", @@ -444,6 +470,9 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -460,11 +489,13 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -476,6 +507,7 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -532,6 +564,9 @@ xla_test( xla_test( name = "deconstruct_tuple_test", srcs = ["deconstruct_tuple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -568,6 +603,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -608,9 +644,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -622,8 +658,10 @@ xla_test( xla_test( name = "dot_operation_test", srcs = ["dot_operation_test.cc"], + shard_count = 20, tags = [ "enable_for_xla_interpreter", + "optonly", ], deps = [ "//tensorflow/compiler/xla:array2d", @@ -642,32 +680,21 @@ xla_test( ], ) -# Tests the dot operation in some cases that can be performed via a -# runtime call on some backends - e.g. a runtime call to Eigen. xla_test( - name = "dot_operation_runtime_test", - srcs = ["dot_operation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], + name = "gather_operation_test", + srcs = ["gather_operation_test.cc"], deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) -# Repeat dot_operation_runtime_test with single-threded eigen. +# Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", srcs = ["dot_operation_test.cc"], @@ -679,6 +706,8 @@ xla_test( "--xla_cpu_multi_thread_eigen=false", ], }, + shard_count = 20, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -699,6 +728,9 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -717,6 +749,9 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -747,10 +782,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -826,11 +861,11 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -877,8 +912,7 @@ xla_test( name = "half_test", srcs = ["half_test.cc"], backends = [ - # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25 - # "cpu", + "cpu", "gpu", ], deps = [ @@ -902,11 +936,14 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -918,6 +955,9 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -934,14 +974,16 @@ xla_test( name = "dynamic_ops_test", timeout = "moderate", srcs = ["dynamic_ops_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -960,6 +1002,9 @@ xla_test( xla_test( name = "tuple_test", srcs = ["tuple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -970,7 +1015,10 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -980,6 +1028,9 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -998,6 +1049,10 @@ xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], shard_count = 40, + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1013,6 +1068,8 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1035,10 +1092,11 @@ xla_test_library( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1075,11 +1133,11 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1091,6 +1149,9 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", @@ -1109,6 +1170,9 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1122,6 +1186,9 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1159,6 +1226,9 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1175,6 +1245,9 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1192,15 +1265,18 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1212,6 +1288,9 @@ xla_test( xla_test( name = "fmax_test", srcs = ["fmax_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1225,6 +1304,9 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1238,6 +1320,9 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1252,6 +1337,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1280,6 +1366,9 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1291,10 +1380,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1306,11 +1395,14 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1322,6 +1414,9 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", @@ -1345,6 +1440,9 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1352,9 +1450,9 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1365,8 +1463,12 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1382,11 +1484,14 @@ xla_test( xla_test( name = "bitcast_convert_test", srcs = ["bitcast_convert_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1421,6 +1526,9 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1447,6 +1555,8 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1467,6 +1577,8 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1504,6 +1616,9 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -1526,6 +1641,9 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1582,6 +1700,9 @@ xla_test( xla_test( name = "fusion_test", srcs = ["fusion_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1704,9 +1825,8 @@ tf_cc_test( deps = [ ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/core:test_main", @@ -1854,16 +1974,15 @@ tf_cc_test( ], ) -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +xla_test( + name = "test_utils_test", + srcs = ["test_utils_test.cc"], + deps = [ + ":local_client_test_base", + ":test_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 7e9005001db34d403ea923eb9c152d114bf32803..03c91745b978f80801e0da5ac44d31959659b20c 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,28 +51,28 @@ class ArrayElementwiseOpTestParamCount public ::testing::WithParamInterface {}; XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1, 324, std::numeric_limits::min(), std::numeric_limits::max()}); - auto result = builder.Neg(a); + builder.Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it @@ -83,28 +84,55 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); - auto result = builder.Neg(a); + builder.Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({ + -1, + 1, + 0, + 0x12345678, + static_cast(0xffffffff12345678l), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + }); + builder.Neg(a); + LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); + + ComputeAndCompareR1(&builder, + { + 1, + -1, + 0, + -0x12345678, + 0xedcba988, + static_cast(0x8000000000000000LL), + -static_cast(0x8000000000000001LL), + }, + {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } @@ -113,64 +141,63 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.IsFinite(builder.ConstantR0(NAN)); + XlaBuilder builder(TestName()); + builder.IsFinite(builder.ConstantR0(NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); - auto result_non_canonical = - builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); + builder.IsFinite(builder.ConstantR0(kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); - auto result_inf = builder.IsFinite(builder.ConstantR0(inf)); + builder.IsFinite(builder.ConstantR0(inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_neg_inf = builder.IsFinite(builder.ConstantR0(-inf)); + builder.IsFinite(builder.ConstantR0(-inf)); ComputeAndCompareR0(&builder, false, {}); - auto result_zero = builder.IsFinite(builder.ConstantR0(0.0f)); + builder.IsFinite(builder.ConstantR0(0.0f)); ComputeAndCompareR0(&builder, true, {}); } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); auto a = builder.ConstantR1( {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); - auto result = builder.IsFinite(a); + builder.IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); auto b = builder.ConstantR1( {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, @@ -178,17 +205,97 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { + ComputationBuilder b(client_, TestName()); + + std::vector lhs{0xFFFFFFFF, + static_cast(-1), + 0, + 0, + 0x7FFFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFLL, + 0x8000000000000000LL, + 0x8000000000000000LL, + 1}; + std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_data = + client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + + std::vector rhs{1, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 0x8000000000000000LL, + 0, + static_cast(-1), + 0, + 1, + 0x8000000000000000LL}; + std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_data = + client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + + b.Add(lhs_param, rhs_param); + + std::vector expected(lhs.size()); + for (int64 i = 0; i < lhs.size(); ++i) { + expected[i] = lhs[i] + rhs[i]; + } + + ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { + ComputationBuilder b(client_, TestName()); + + std::vector lhs{static_cast(0x8000000000000000LL), + static_cast(0x8000000000000000LL), + -1, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 1, + 0, + -1}; + std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_data = + client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + + std::vector rhs{-1, + 0, + static_cast(0x8000000000000000LL), + 1, + 0, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL}; + std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_data = + client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + + auto sub = b.Sub(lhs_param, rhs_param); + + std::vector expected(lhs.size()); + for (int64 i = 0; i < lhs.size(); ++i) { + expected[i] = lhs[i] - rhs[i]; + } + + ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector a_values; std::vector b_values; for (int i = 0; i < count; ++i) { @@ -227,49 +334,49 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 2, 1000000000}); auto b = builder.ConstantR1({-1, 2, 1, -1}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1( &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, @@ -277,29 +384,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Sub(a, b); + builder.Sub(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.1f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -329,9 +436,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -344,8 +451,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -354,9 +461,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -369,8 +476,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { // Test with a compile-time constant divisor. { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -400,9 +507,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -414,8 +521,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Div(dividend, builder.ConstantR1(divisors)); @@ -424,9 +531,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = @@ -438,8 +545,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; + XlaBuilder builder(TestName()); + XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); builder.Rem(dividend, builder.ConstantR1(divisors)); @@ -449,33 +556,33 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); auto b = builder.ConstantR1( {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto div = builder.Div(a, b); + builder.Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); auto b = builder.ConstantR1( {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, @@ -483,21 +590,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); auto b = builder.ConstantR1( {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, @@ -505,20 +612,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -541,19 +648,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } @@ -572,21 +679,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(a_data); auto b = builder.ConstantR1(b_data); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); auto b = builder.ConstantR1( {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, @@ -594,378 +701,386 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto add = builder.Mul(a, b); + builder.Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, -8}); auto b = builder.ConstantR1({5, -7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); auto b = builder.ConstantR2({{1, -6}, {4, 5}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 12}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {3, 8}}); auto b = builder.ConstantR2({{1, 0}, {7, 6}}); - auto out = builder.And(a, b); + builder.And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.And(a, b); + builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, false}, {true, true}}); auto b = builder.ConstantR2({{false, true}, {false, true}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, -1, 8}); auto b = builder.ConstantR1({5, -7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, -1}, {8, 8}}); auto b = builder.ConstantR2({{5, -7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 1, 8}); auto b = builder.ConstantR1({5, 7, 4}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 1}, {8, 8}}); auto b = builder.ConstantR2({{5, 7}, {4, 1}}); - auto out = builder.Or(a, b); + builder.Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.Or(a, b); + builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, true}, {true, false}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-1, 0, 1}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, 4294967295}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); - auto out = builder.Not(a); + builder.Not(a); Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto out = builder.Not(a); + builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { - ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x12345678), - static_cast(0xF0001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15}); - auto out = builder.ShiftLeft(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({static_cast(0x12345678), + static_cast(0xF0001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); + builder.ShiftLeft(a, b); - ComputeAndCompareR1( - &builder, - {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); + ComputeAndCompareR1(&builder, + {static_cast(0x23456780), 0x00100000, 0x4, + 0x180, 2523136, 0, 0, 0}, + {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { - ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2}); - auto out = builder.ShiftRightArithmetic(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); + builder.ShiftRightArithmetic(a, b); - ComputeAndCompareR1(&builder, - {static_cast(0xF9234567), - static_cast(0x00100010), 0, 0, 19}, - {}); + ComputeAndCompareR1( + &builder, + {static_cast(0xF9234567), static_cast(0x00100010), 0, 0, 19, + 0, -1, 0}, + {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { - ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5}); - auto out = builder.ShiftRightLogical(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); + builder.ShiftRightLogical(a, b); - ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); + ComputeAndCompareR1(&builder, + {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x12345678, 0xF0001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15}); - auto out = builder.ShiftLeft(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); + builder.ShiftLeft(a, b); ComputeAndCompareR1( - &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2}); - auto out = builder.ShiftRightArithmetic(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); + builder.ShiftRightArithmetic(a, b); - ComputeAndCompareR1(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); + ComputeAndCompareR1( + &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { - ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5}); - auto out = builder.ShiftRightLogical(a, b); + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); + builder.ShiftRightLogical(a, b); - ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); + ComputeAndCompareR1(&builder, + {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 2.25f, 10.0f, NAN}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 5.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, NAN}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, false, false, false}, {}); } @@ -973,10 +1088,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -984,17 +1099,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1005,16 +1120,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}); } @@ -1023,7 +1138,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, {1.0f, 25.5f}, {2.25f, -3.0f}, @@ -1034,7 +1149,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { {2.25f, -3.0f}, {10.0f, 0.0f}, {1.0f, NAN}}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); } @@ -1043,10 +1158,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.5f, 25.5f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({10.0f, 25.5f, 1.0f, 10.0f, NAN}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1(&builder, {true, false, true, true, true}, {}); } @@ -1054,10 +1169,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1066,10 +1181,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1078,10 +1193,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1091,10 +1206,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1103,10 +1218,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({min, min, min, 0, 0, 0, max, max, max}); auto rhs = builder.ConstantR1({min, 0, max, -1, 0, 1, min, 0, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1115,10 +1230,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Eq(lhs, rhs); + builder.Eq(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, false, true, false, false, false, true}, @@ -1127,10 +1242,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ne(lhs, rhs); + builder.Ne(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, true, false, true, true, true, false}, {}); @@ -1138,10 +1253,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Ge(lhs, rhs); + builder.Ge(lhs, rhs); ComputeAndCompareR1( &builder, {true, false, false, true, true, false, true, true, true}, {}); @@ -1149,10 +1264,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Gt(lhs, rhs); + builder.Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, @@ -1161,10 +1276,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Le(lhs, rhs); + builder.Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); @@ -1172,10 +1287,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = builder.ConstantR1({0, 1, max, 4, 5, 6, 0, 1, max}); - auto compare = builder.Lt(lhs, rhs); + builder.Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, @@ -1184,12 +1299,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = builder.ConstantR1({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); @@ -1197,20 +1312,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({-2.0f, -0.6f, -0.6f, 0.0f}); auto rhs = builder.ConstantR1({0.5f, 0.6f, -0.6f, -0.6f}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Pow(lhs, rhs); + builder.Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } @@ -1484,14 +1599,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector values; values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } auto x = builder.ConstantR1(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); std::vector expected; expected.reserve(values.size()); @@ -1503,7 +1618,7 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 2, 2); std::vector values_vector; @@ -1517,140 +1632,86 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { Array4D expected(2, 2, 2, 2, expected_vector); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); auto x = builder.ConstantR4FromArray4D(values); - auto exp = builder.Pow(x, builder.ConstantR0(2.0f)); + builder.Pow(x, builder.ConstantR0(2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -// GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT -// such -// * fmin(NaN, x) = x -// * fmax(NaN, x) = x -// so we only test NAN on CPU. -// -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); -#endif - auto minimum = builder.Min(lhs, rhs); - - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {1.0f, -5.0f, 1.0f}, -#else - {1.0f, -5.0f, 1.0f, 10.0f, 6.0f}, -#endif - {}, error_spec_); + builder.Min(lhs, rhs); + + ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); -#endif - auto minimum = builder.Min(lhs, rhs); + builder.Min(lhs, rhs); - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {1.0, -5.0, 1.0}, -#else - {1.0, -5.0, 1.0, 10.0, 6.0}, -#endif - {}, error_spec_); + ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, + error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f}); - auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = builder.ConstantR1({2.0f, -5.0f, 1.0f, 10.0f, NAN}); -#endif - auto maximum = builder.Max(lhs, rhs); - - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {2.0f, 1.0f, 2.25f}, -#else - {2.0f, 1.0f, 2.25f, 10.0f, 6.0f}, -#endif - {}, error_spec_); + builder.Max(lhs, rhs); + + ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); - auto minimum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -// TODO(b/28180546): Make this compile in a way that is consistent -// among backends. See comment on MinF32s test above. XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { - ComputationBuilder builder(client_, TestName()); -#if !defined(XLA_TEST_BACKEND_CPU) - auto lhs = builder.ConstantR1({1.0, 1.0, 2.25}); - auto rhs = builder.ConstantR1({2.0, -5.0, 1.0}); -#else + XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = builder.ConstantR1({1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = builder.ConstantR1({2.0, -5.0, 1.0, 10.0, NAN}); -#endif - auto maximum = builder.Max(lhs, rhs); + builder.Max(lhs, rhs); - ComputeAndCompareR1(&builder, -#if !defined(XLA_TEST_BACKEND_CPU) - {2.0, 1.0, 2.25}, -#else - {2.0, 1.0, 2.25, 10.0, 6.0}, -#endif - {}, error_spec_); + ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, + error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1665,7 +1726,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = builder.ConstantR1( @@ -1679,7 +1740,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Max(x, y); @@ -1690,7 +1751,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 0, 1, 1, 1, max, max, max}); auto y = builder.ConstantR1({0, 1, 0, 1, 10, 0, 234234, max}); builder.Min(x, y); @@ -1700,7 +1761,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( @@ -1713,7 +1774,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR1({}); builder.Max(u, v); @@ -1723,7 +1784,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto u = builder.ConstantR1({3.5}); auto v = builder.ConstantR2FromArray2D(Array2D(0, 2)); builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); @@ -1733,7 +1794,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({2.0f, 3.0f, 4.0f}); auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); @@ -1744,7 +1805,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({}); auto m = builder.ConstantR2({{}, {}}); builder.Max(v, m, /*broadcast_dimensions=*/{1}); @@ -1754,7 +1815,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1765,7 +1826,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto scalar = builder.ConstantR0(2); Array3D a_3d(2, 0, 3); auto array = builder.ConstantR3FromArray3D(a_3d); @@ -1776,7 +1837,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); @@ -1787,7 +1848,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{}, {}}); auto v = builder.ConstantR1({-10.2f, 16.4f}); builder.Min(m, v, /*broadcast_dimensions=*/{0}); @@ -1797,7 +1858,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); auto array4d = builder.ConstantR4FromArray4D( @@ -1812,7 +1873,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto array2d = builder.ConstantR2({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); Array4D arg(2, 2, 0, 3); @@ -1824,7 +1885,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Min(x, y); @@ -1834,7 +1895,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); auto y = builder.ConstantR1({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); builder.Max(x, y); @@ -1844,110 +1905,107 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { } XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-3, 26, 2, -1, 1}); auto b = builder.ConstantR1({10, 5, 1, 10, -10}); - auto add = builder.Rem(a, b); + builder.Rem(a, b); ComputeAndCompareR1(&builder, {-3, 1, 0, -1, 1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); auto maximum = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto minimum = builder.ConstantR0(0.0f); auto argument = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto maximum = builder.ConstantR0(5.0f); - auto clamp = builder.Clamp(minimum, argument, maximum); + builder.Clamp(minimum, argument, maximum); ComputeAndCompareR1(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); - auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + builder.Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min_scalar = builder.ConstantR0(0); auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); auto max_scalar = builder.ConstantR0(3); auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. - auto clamp = builder.Add( - builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), - builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), - builder.Clamp(min_scalar, arg_vector, max_scalar))); + builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -1961,7 +2019,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, @@ -1969,7 +2027,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); @@ -1983,7 +2041,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto add = builder.Add(p0, p1); + builder.Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( @@ -1991,7 +2049,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); @@ -2000,35 +2058,35 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto p = builder.Parameter(0, param0_literal->shape(), "param0"); - auto add = builder.Add(a, p); + builder.Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Cos(a); + builder.Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); - auto result = builder.Sin(a); + builder.Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); - auto atan = builder.Atan2(a, b); + builder.Atan2(a, b); ComputeAndCompareR1( &builder, @@ -2037,9 +2095,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { } XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); - auto result = builder.Tanh(a); + builder.Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); @@ -2049,7 +2107,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, @@ -2088,7 +2146,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. @@ -2124,7 +2182,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, @@ -2164,14 +2222,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // / / // b -----/ / // c---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({1.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(a, b); - auto add2 = builder.Add(add, c); + builder.Add(add, c); ComputeAndCompareR1(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2182,14 +2240,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { // / / // c -----/ / // a---------------------/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto c = builder.ConstantR1({-3.3f, -15.5f, -7.7f, -29.9f}); auto add = builder.Add(b, c); - auto add2 = builder.Add(a, add); + builder.Add(a, add); ComputeAndCompareR1(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, error_spec_); @@ -2199,14 +2257,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) { // a ----- (neg) ----- (add) // / // b ----- (neg) ----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); auto neg_a = builder.Neg(a); auto neg_b = builder.Neg(b); - auto result = builder.Add(neg_a, neg_b); + builder.Add(neg_a, neg_b); ComputeAndCompareR1(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, error_spec_); @@ -2220,7 +2278,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { // c ------ (add) ------------/ // / // d -----/ - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({91.1f, 2.2f, 3.3f, 4.4f}); auto b = builder.ConstantR1({2.1f, 3.2f, 4.3f, 5.4f}); @@ -2229,19 +2287,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { auto add_ab = builder.Add(a, b); auto add_cd = builder.Add(c, d); - auto add_all = builder.Add(add_ab, add_cd); + builder.Add(add_ab, add_cd); ComputeAndCompareR1(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2250,11 +2308,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { // Add a scalar + matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(scalar, a); + builder.Add(scalar, a); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2262,11 +2320,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { // Add a matrix + scalar. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto scalar = builder.ConstantR0(3.0f); - auto add = builder.Add(a, scalar); + builder.Add(a, scalar); Array2D expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2275,14 +2333,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches // only dim 0 of the matrix. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f, 60.0f}); // clang-format off auto m = builder.ConstantR2({ {-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); // clang-format on - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array( {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2308,10 +2366,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { // Test broadcasting in Ne comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({42, 73}); auto m = builder.ConstantR2({{42, 73}, {42, 52}}); - auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + builder.Ne(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,2] { { 00 }, @@ -2322,10 +2380,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { // Test broadcasting in Ge comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + builder.Ge(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1100 }, @@ -2336,10 +2394,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { // Test broadcasting in Gt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + builder.Gt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0100 }, @@ -2350,10 +2408,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { // Test broadcasting in Le comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1}); + builder.Le(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 1011 }, @@ -2364,10 +2422,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { // Test broadcasting in Lt comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, 2, 3, 4}); auto m = builder.ConstantR2({{1, 0, 5, 6}, {42, 52, 10, 4}}); - auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + builder.Lt(v, m, /*broadcast_dimensions=*/{1}); const string expected = R"(pred[2,4] { { 0011 }, @@ -2379,24 +2437,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op // arguments is reversed. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantR2({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); auto v = builder.ConstantR1({2.0f, 4.0f, 6.0f}); - auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + builder.Mul(m, v, /*broadcast_dimensions=*/{1}); Array2D expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {3, 1} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f, 20.0f, 30.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2404,14 +2462,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { // Tests broadcasting for arrays with degenerate (size == 1) dimensions. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // m's shape in XLA notation is {3, 2} // md's shape in XLA notation is {1, 2} // The result has shape {3, 2}, where md is broadcast over m auto m = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto md = builder.ConstantR2({{10.0f}, {20.0f}}); - auto add = builder.Add(m, md); + builder.Add(m, md); Array2D expected_array( {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); @@ -2422,13 +2480,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { // effectively creates an "outer product" operation. // This is taken from the Numpy docs example at: // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // a's shape in XLA notation is {1, 4} // b's shape in XLA notation is {3, 1} // The result has shape {3, 4}. auto a = builder.ConstantR2({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); auto b = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto add = builder.Add(a, b); + builder.Add(a, b); Array2D expected_array({{1.0f, 2.0f, 3.0f}, {11.0f, 12.0f, 13.0f}, {21.0f, 22.0f, 23.0f}, @@ -2439,10 +2497,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { // Add together a (2,2) array and a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + builder.Add(v, m, /*broadcast_dimensions=*/{1}); Array2D expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } @@ -2450,17 +2508,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { // Add together a (2,2) array and a (2) array, using dimension 1 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({20.0f, 40.0f}); auto m = builder.ConstantR2({{10.0f, 50.0f}, {77.0f, 88.0f}}); - auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0}); + builder.Add(v, m, /*broadcast_dimensions=*/{0}); Array2D expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { // Binary add of two R3s together - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2468,7 +2526,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { Array3D b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto add = builder.Add(a, b); + builder.Add(a, b); Array3D expected_3d( {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, @@ -2479,7 +2537,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2492,7 +2550,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2}); + builder.Add(a, v, /*broadcast_dimensions=*/{2}); Array3D expected_3d( {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, @@ -2503,7 +2561,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for // broadcasting (though there are two ways to broadcast these shapes). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2516,7 +2574,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { // clang-format on auto a = builder.ConstantR3FromArray3D(a_3d); auto v = builder.ConstantR1({10.0f, 20.0f}); - auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0}); + builder.Add(a, v, /*broadcast_dimensions=*/{0}); // clang-format off Array3D expected_3d({ @@ -2534,7 +2592,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2} // for broadcasting. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array3D a_3d({ {{1.0f, 2.0f}, @@ -2549,7 +2607,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { {10.0f, 20.0f, 30.0f}, {40.0f, 50.0f, 60.0f}, }); - auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); Array3D expected_3d({ {{11.0f, 12.0f}, @@ -2566,7 +2624,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { // Comparison between two 3D arrays of compatible shapes: // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); auto a = builder.ConstantR3FromArray3D(a_3d); @@ -2574,7 +2632,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { Array3D b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); auto b = builder.ConstantR3FromArray3D(b_3d); - auto compare = builder.Gt(a, b); + builder.Gt(a, b); Array3D expected_3d( {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}}); @@ -2590,7 +2648,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { } XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> operand_b_4d(new Array4D(2, 3, 4, 5)); @@ -2611,13 +2669,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR4FromArray4D(*operand_b_4d); - auto add = builder.Add(a, b); + builder.Add(a, b); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr> operand_a_4d(new Array4D(2, 3, 4, 5)); std::unique_ptr> expected_4d(new Array4D(2, 3, 4, 5)); @@ -2639,7 +2697,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { auto a = builder.ConstantR4FromArray4D(*operand_a_4d); auto b = builder.ConstantR1(operand_b_1d); - auto add = builder.Add(a, b, {1}); + builder.Add(a, b, {1}); ComputeAndCompareR4(&builder, *expected_4d, {}, error_spec_); } @@ -2654,7 +2712,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::vector r1(d1); std::iota(r1.begin(), r1.end(), 1.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = builder.ConstantLiteral(*a_literal); @@ -2675,11 +2733,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { // Show that we can't add two opaques. XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto shape = ShapeUtil::MakeOpaqueShape(); auto x = builder.Parameter(0, shape, "x"); - auto concatenated = builder.Add(x, x); - StatusOr computation_status = builder.Build(); + builder.Add(x, x); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( @@ -2687,12 +2745,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); Array2D expected_array( {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); @@ -2700,14 +2758,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { } XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); auto b = builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); - auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); - StatusOr computation_status = builder.Build(); + auto computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().error_message(), ::testing::ContainsRegex("must.*be the identity")); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 3f6fd7c65d3360a622dbf754833009fb20410535..ec3b46acfec0ee0ff514a862ce5b1ca74279efa8 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,11 +29,11 @@ namespace { class AxpySimpleTest : public ClientLibraryTestBase {}; TEST_F(AxpySimpleTest, AxTenValues) { - ComputationBuilder builder(client_, "ax_10"); + XlaBuilder builder("ax_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Mul(alpha, x); + builder.Mul(alpha, x); std::vector expected = { -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796, @@ -46,7 +47,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { auto x = builder.ConstantR1({}); auto y = builder.ConstantR1({}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -60,7 +61,7 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto y = builder.ConstantR1( {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0}); auto ax = builder.Mul(alpha, x); - auto axpy = builder.Add(ax, y); + builder.Add(ax, y); TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 28ab9654997728fbafd6610af840e721e72cce5a..f3dac75a44b948c4b45b80b93e7462073010979e 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -69,6 +69,15 @@ class BatchNormalizationTest CHECK_EQ(kY, input_array_.width()); } + XlaOp CheckShape(XlaBuilder* b, const XlaOp& operand, + const Shape& expected_shape) const { + Shape actual_shape = b->GetShape(operand).ConsumeValueOrDie(); + CHECK(ShapeUtil::Equal(expected_shape, actual_shape)) + << "want " << ShapeUtil::HumanString(expected_shape) << " got " + << ShapeUtil::HumanString(actual_shape); + return operand; + } + static constexpr int64 kSamples = 3; static constexpr int64 kX = 1; static constexpr int64 kY = 1; @@ -91,7 +100,7 @@ INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, #endif XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { - ComputationBuilder builder(client_, "subtract_in_z_one_sample"); + XlaBuilder builder("subtract_in_z_one_sample"); auto x = builder.ConstantLiteral(input_literal_); auto y = builder.ConstantR1({3.14, 4.25}); builder.Sub(x, y, /*broadcast_dimensions=*/{1}); @@ -107,7 +116,7 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { } XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { - ComputationBuilder builder(client_, "square_tesseract_elementwise"); + XlaBuilder builder("square_tesseract_elementwise"); auto x = builder.ConstantLiteral(input_literal_); builder.SquareF32(x); @@ -124,9 +133,9 @@ XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { } XLA_TEST_P(BatchNormalizationTest, SumToZ) { - ComputationBuilder builder(client_, "sum_to_z"); + XlaBuilder builder("sum_to_z"); auto input_activations = builder.ConstantLiteral(input_literal_); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all but the Z dimension. builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, {0, 2, 3}); @@ -136,24 +145,23 @@ XLA_TEST_P(BatchNormalizationTest, SumToZ) { } XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { - ComputationBuilder builder(client_, "square_and_reduce"); + XlaBuilder builder("square_and_reduce"); auto input_activations = builder.ConstantLiteral(input_literal_); auto set_means = builder.ConstantR1({2.f, 4.2f}); auto activation_deviations = builder.Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = builder.Reduce( - dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); + builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, {0, 2, 3}); std::vector expected = {18, 0.06}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); } XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { - ComputationBuilder builder(client_, "variance_to_stddev"); + XlaBuilder builder("variance_to_stddev"); auto variance = builder.ConstantR1({6.f, .02f}); - auto sqrt = builder.SqrtF32(variance); + builder.SqrtF32(variance); std::vector expected = {2.44948974f, 0.14142136f}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); @@ -162,23 +170,24 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { // Compare against a forward batch normalization example in the NN spec // reference. XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { - ComputationBuilder builder(client_, "batch_normalize_per_spec"); + XlaBuilder builder("batch_normalize_per_spec"); auto input_activations = - builder.CheckShape(builder.ConstantLiteral(input_literal_), - ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); + CheckShape(&builder, builder.ConstantLiteral(input_literal_), + ShapeUtil::MakeShape(F32, {3, 2, 1, 1})); auto gamma = builder.ConstantR1({1.0, 1.0}); auto beta = builder.ConstantR1({0.0, 0.0}); - Computation add = CreateScalarAddComputation(F32, &builder); + XlaComputation add = CreateScalarAddComputation(F32, &builder); // Reduce all dimensions except dimension 1. Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2}); - auto sum = builder.CheckShape( + auto sum = CheckShape( + &builder, builder.Reduce(input_activations, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie(); auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie(); - auto count = builder.ConstantR0(ShapeUtil::ElementsIn(*input_shape) / - ShapeUtil::ElementsIn(*sum_shape)); + auto count = builder.ConstantR0(ShapeUtil::ElementsIn(input_shape) / + ShapeUtil::ElementsIn(sum_shape)); auto set_means = builder.Div(sum, count); const float kEpsilon = 1e-9f; @@ -187,14 +196,16 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { auto activation_deviations = builder.Sub(input_activations, set_means, /*broadcast_dimensions=*/{1}); auto dev_squares = builder.SquareF32(activation_deviations); - auto sum_of_squares = builder.CheckShape( + auto sum_of_squares = CheckShape( + &builder, builder.Reduce(dev_squares, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0, 2, 3}), TwoElementVectorF32); auto variance = builder.Div(sum_of_squares, count); auto standard_deviation = builder.SqrtF32(variance); - auto standard_deviation_above_epsilon = builder.CheckShape( - builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2})); + auto standard_deviation_above_epsilon = + CheckShape(&builder, builder.Gt(standard_deviation, epsilon), + ShapeUtil::MakeShape(PRED, {2})); auto gt_eps = builder.Select(standard_deviation_above_epsilon, standard_deviation, epsilon2); auto normalization_factors = builder.ReciprocalF32(gt_eps); @@ -219,7 +230,7 @@ XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { XLA_TEST_P(BatchNormalizationTest, BasicTraining) { const int kFeatureIndex = 3; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); @@ -228,8 +239,8 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { auto offset = builder.ConstantR1({1.0f, 2.0f}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, @@ -243,7 +254,7 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); @@ -252,8 +263,8 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { auto offset = builder.ConstantR1({1.0f, 2.0f}); - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); + builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, @@ -268,23 +279,23 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { // Use 0 dimension as feature, tests layout analyzer. const int kFeatureIndex = 0; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle h0; + XlaOp h0; auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), /*parameter_number=*/0, "operand", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto scale = CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; + XlaOp h2; auto offset = CreateR1Parameter(std::vector(260, 1.0f), /*parameter_number=*/2, "offset", &builder, &h2); - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); + builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) @@ -300,24 +311,24 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { // Test the correctness of choosing a large epsilon value. const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle h0; + XlaOp h0; auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, /*parameter_number=*/0, "operand", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto scale = CreateR1Parameter(std::vector(1, 1.0f), /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; + XlaOp h2; auto offset = CreateR1Parameter(std::vector(1, 0.0f), /*parameter_number=*/2, "offset", &builder, &h2); // var = 125, mean = 15, epsilon = -100 - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); + builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); auto expected = Literal::MakeTuple( {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) @@ -332,7 +343,7 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); @@ -439,7 +450,7 @@ INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes, XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { float epsilon = 0.001; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const std::vector& bounds = GetParam().bounds; Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); input_array.FillRandom(GetParam().random_value_var, @@ -539,7 +550,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { float epsilon = 0.001; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const std::vector& bounds = GetParam().bounds; Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); input_array.FillRandom(GetParam().random_value_var, @@ -647,7 +658,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { float epsilon = 0.001; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const std::vector& bounds = GetParam().bounds; Array4D input_array(bounds[0], bounds[1], bounds[2], bounds[3]); input_array.FillRandom(GetParam().random_value_var, @@ -814,9 +825,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { std::unique_ptr grad_output_data = client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); - auto t = builder.BatchNormGrad(input_parameter, scale_parameter, - mean_parameter, var_parameter, - grad_output_parameter, epsilon, feature_index); + builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter, + var_parameter, grad_output_parameter, epsilon, + feature_index); auto expected = Literal::MakeTuple({expected_grad_activation.get(), diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index 0d94d65c1015fb54ada3fdfc95d0c31d0a0f158b..777ac167a3c38c38791e12541a5db3078c37595b 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -42,7 +42,7 @@ class BitcastConvertTest : public ClientLibraryTestBase { }; TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.BitcastConvertType(a, S32); @@ -51,7 +51,7 @@ TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { } TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0f, 64.0f}); builder.BitcastConvertType(a, F32); @@ -60,7 +60,7 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { } TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, static_cast(0xBF800000), 0x3F000000, @@ -72,7 +72,7 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { } XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); builder.BitcastConvertType(a, F32); @@ -81,7 +81,7 @@ XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { } TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.6, 64.4}); builder.BitcastConvertType(a, S32); @@ -90,7 +90,7 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { } TEST_F(BitcastConvertTest, ConvertS32Extremes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {std::numeric_limits::min(), std::numeric_limits::max()}); builder.BitcastConvertType(a, F32); @@ -100,7 +100,7 @@ TEST_F(BitcastConvertTest, ConvertS32Extremes) { } TEST_F(BitcastConvertTest, ConvertMapToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->BitcastConvertType(param, S32); @@ -112,7 +112,7 @@ TEST_F(BitcastConvertTest, ConvertMapToS32) { } TEST_F(BitcastConvertTest, ConvertMapToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->BitcastConvertType(param, F32); @@ -129,7 +129,7 @@ TEST_F(BitcastConvertTest, ConvertMapToF32) { // input -> convert -> reshape // the new convert should have the same element type as the old convert. TEST_F(BitcastConvertTest, ConvertReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR1({0x42280000}); auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); builder.BitcastConvertType(reshape, F32); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 03f5e08315bfed2bcb43ebb7098aaa0b97228605..97095f1cc427789845051a8fea24c95475286fe2 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -662,7 +662,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("broadcast dimension 0 mismatch")); + HasSubstr("dimension 0 mismatch")); } XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { @@ -675,7 +675,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("binary op BINOP_ADD with incompatible shapes")); + HasSubstr("op BINOP_ADD with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -688,7 +688,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("binary op BINOP_ADD with incompatible shapes")); + HasSubstr("op BINOP_ADD with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 610302ac1256a57db6ed6e18016a4136973e3891..eac2eb286c3f7a1cd33aed03686e99ef753b773a 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -137,7 +137,8 @@ def xla_test(name, backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] this_backend_tags += ["requires-gpu-sm35"] elif backend in plugins: - backend_deps = plugins[backend]["deps"] + backend_deps = [] + backend_deps += plugins[backend]["deps"] this_backend_copts += plugins[backend]["copts"] this_backend_tags += plugins[backend]["tags"] this_backend_args += plugins[backend]["args"] diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index a677986cd926cc0054d8f36abc98ccac33dc043d..312d8f284d3421b4ef06b94c12949fc5fe4fa0b0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -35,6 +36,10 @@ namespace se = ::perftools::gputools; namespace xla { namespace { + +// Name of the interpreter backend. +constexpr char kInterpreter[] = "interpreter"; + // Wrapper function that creates a nicer error message (than a bare // ValueOrDie()) if the platform we intend to test is not available. Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { @@ -43,6 +48,14 @@ Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); } + +// Helper functions to get the reference platform. +se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + } // namespace ClientLibraryTestBase::ClientLibraryTestBase( @@ -66,6 +79,11 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) LocalClientOptions default_options; default_options.set_platform(platform); client_ = GetOrCreateLocalClientOrDie(default_options); + + LocalClientOptions ref_options; + ref_options.set_platform(GetReferencePlatform()); + ref_client_ = GetOrCreateLocalClientOrDie(ref_options); + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "constant_folding"); } @@ -74,9 +92,9 @@ string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } +template StatusOr> ClientLibraryTestBase::Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { + BuilderT* builder, tensorflow::gtl::ArraySlice arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); @@ -95,6 +113,20 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + ExecutionOptions execution_options = execution_options_; + if (shape_with_output_layout != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *shape_with_output_layout; + } + return client_->ExecuteAndTransfer(computation, arguments, + &execution_options); +} + +template <> StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments, @@ -104,6 +136,29 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } +template <> +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + +StatusOr> +ClientLibraryTestBase::ExecuteAndTransferReference( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + ExecutionOptions execution_options = execution_options_; + if (shape_with_output_layout != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *shape_with_output_layout; + } + return ref_client_->ExecuteAndTransfer(computation, arguments, + &execution_options); +} + std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { @@ -116,14 +171,31 @@ std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); } +string ClientLibraryTestBase::ExecuteToString( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + auto computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status().ToString(); + } + auto computation = computation_status.ConsumeValueOrDie(); + + auto result = + client_->ExecuteAndTransfer(computation, arguments, &execution_options_); + if (!result.ok()) { + return result.status().ToString(); + } else { + return result.ValueOrDie()->ToString(); + } +} + string ClientLibraryTestBase::ExecuteToString( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { - StatusOr computation_status = builder->Build(); + auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); } - Computation computation = computation_status.ConsumeValueOrDie(); + auto computation = computation_status.ConsumeValueOrDie(); auto result = client_->ExecuteAndTransfer(computation, arguments, &execution_options_); @@ -142,16 +214,18 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } +template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, @@ -249,8 +323,28 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice /*arguments*/, + const std::function& /*verify_output*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +tensorflow::Status +ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, + tensorflow::gtl::ArraySlice /*arguments*/, + const std::function& /*verify_output*/, + const Shape* /*output_with_layout*/) { + return Unimplemented("not yet implemented for XlaComputation"); +} + +template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -307,8 +401,9 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( return tensorflow::Status::OK(); } +template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -378,8 +473,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( EXPECT_EQ(expected, actual->GetR1U8AsString()); } +template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -390,8 +486,9 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectEqual(expected, *actual); } +template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -456,6 +553,69 @@ ClientLibraryTestBase::ComputeValueAndReference( return std::make_pair(std::move(reference), std::move(result)); } +void ClientLibraryTestBase::ComputeAndCompare( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + auto status_or_data = ComputeValueAndReference(builder, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*reference, *result); +} + +void ClientLibraryTestBase::ComputeAndCompare( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + ErrorSpec error) { + auto status_or_data = ComputeValueAndReference(builder, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectNear(*reference, *result, error); +} + +StatusOr, std::unique_ptr>> +ClientLibraryTestBase::ComputeValueAndReference( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + // Transfer the arguments to the executor service. We put the unique_ptr's + // into a vector to keep the data alive on the service until the end of this + // function. + std::vector> argument_data; + std::vector> ref_argument_data; + for (const auto& arg : arguments) { + TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg.Clone())); + TF_ASSIGN_OR_RETURN(auto ref_data, ref_client_->TransferToServer(arg)); + argument_data.push_back(std::move(data)); + ref_argument_data.push_back(std::move(ref_data)); + } + + // Create raw pointers to the GlobalData for the rest of the call stack. + std::vector argument_data_ptr; + std::transform( + argument_data.begin(), argument_data.end(), + std::back_inserter(argument_data_ptr), + [](const std::unique_ptr& data) { return data.get(); }); + std::vector ref_argument_data_ptr; + std::transform( + ref_argument_data.begin(), ref_argument_data.end(), + std::back_inserter(ref_argument_data_ptr), + [](const std::unique_ptr& data) { return data.get(); }); + + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + + TF_ASSIGN_OR_RETURN(auto result, + ExecuteAndTransfer(computation, argument_data_ptr)); + + TF_ASSIGN_OR_RETURN(auto reference, ExecuteAndTransferReference( + computation, ref_argument_data_ptr)); + + return std::make_pair(std::move(reference), std::move(result)); +} + Computation ClientLibraryTestBase::CreateScalarRelu() { ComputationBuilder builder(client_, "relu"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); @@ -522,33 +682,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - ComputationDataHandle ClientLibraryTestBase::AddParam( const Literal& argument, ComputationBuilder* builder) { ComputationDataHandle data_handle; @@ -557,10 +690,67 @@ ComputationDataHandle ClientLibraryTestBase::AddParam( return data_handle; } +XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, + XlaBuilder* builder) { + XlaOp data_handle; + arguments_.push_back(CreateParameterAndTransferLiteral( + arguments_.size(), argument, "", builder, &data_handle)); + return data_handle; +} + ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( const Literal& literal, ComputationBuilder* builder) { return builder->ConstantLiteral( use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, + XlaBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareLiteral( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + const Shape* shape_with_layout); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + ComputationBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + +template void ClientLibraryTestBase::ComputeAndCompareTuple( + XlaBuilder* builder, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + +template StatusOr> ClientLibraryTestBase::Execute( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + +template StatusOr> ClientLibraryTestBase::Execute( + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index ba0319990bc04196386e6812b0a03671676698ec..b3212dd2282375367ce890e960278fc469a5ef52 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -91,18 +92,36 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience methods for building and running a computation with the member // execution options. Modify execution_options_ in your test if you want to // customize the options. + template StatusOr> Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + BuilderT* builder, tensorflow::gtl::ArraySlice arguments); + + // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once + // the migration to XlaBuilder is complete. + + template StatusOr> ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, + BuilderT* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( const Computation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); + + // This executes the computation via the reference client (which connects a + // interpreter backend). The result is used as the expected values of the + // computation. + StatusOr> ExecuteAndTransferReference( + const XlaComputation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); + // Convenience OrDie variants of above methods. std::unique_ptr ExecuteOrDie( ComputationBuilder* builder, @@ -113,29 +132,31 @@ class ClientLibraryTestBase : public ::testing::Test { // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. + string ExecuteToString(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); string ExecuteToString(ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder for details). For each rank, two forms are provided: one - // for floating point types with an ErrorSpec parameter, and one for integral - // types without the ErrorSpec parameter. - template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + // ComputationBuilder/XlaBuilder for details). For each rank, two forms are + // provided: one for floating point types with an ErrorSpec parameter, and one + // for integral types without the ErrorSpec parameter. + template + void ComputeAndCompareR0(BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected, + template + void ComputeAndCompareR0(BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR1(ComputationBuilder* builder, + template + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR1(ComputationBuilder* builder, + template + void ComputeAndCompareR1(BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); @@ -146,55 +167,53 @@ class ClientLibraryTestBase : public ::testing::Test { const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(ComputationBuilder* builder, - const Array2D& expected, + template + void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(ComputationBuilder* builder, - const Array2D& expected, + template + void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR3(ComputationBuilder* builder, - const Array3D& expected, + template + void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR3(ComputationBuilder* builder, - const Array3D& expected, + template + void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR4(ComputationBuilder* builder, - const Array4D& expected, + template + void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR4(ComputationBuilder* builder, - const Array4D& expected, + template + void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. + template void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); + template void ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. + template tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); + template tensorflow::Status ComputeAndCompareLiteralWithStatus( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -206,11 +225,13 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. + template void ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments); + template void ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, + BuilderT* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result @@ -223,6 +244,14 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + // Convenience method for running a built computation and comparing the result + // with the reference result. + void ComputeAndCompare(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompare(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); + // Create scalar operations for use in reductions. Computation CreateScalarRelu(); Computation CreateScalarMax(); @@ -266,17 +295,19 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. + template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + BuilderT* builder, HandleT* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. + template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters @@ -285,18 +316,24 @@ class ClientLibraryTestBase : public ::testing::Test { // set exactly once. The first added parameter gets index 0, then 1 and so on. ComputationDataHandle AddParam(const Literal& argument, ComputationBuilder* builder); + XlaOp AddParam(const Literal& argument, XlaBuilder* builder); template ComputationDataHandle AddParam(const Array& argument, ComputationBuilder* builder) { return AddParam(*Literal::CreateFromArray(argument), builder); } + template + XlaOp AddParam(const Array& argument, XlaBuilder* builder) { + return AddParam(*Literal::CreateFromArray(argument), builder); + } // Creates a constant instruction with the given literal. When the // use_bfloat16 flag is set but the literal has F32 elements, the elements // will be converted to BF16s. ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, ComputationBuilder* builder); + XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be @@ -307,6 +344,12 @@ class ClientLibraryTestBase : public ::testing::Test { return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); } + template + XlaOp CreateConstantFromArray(const Array& array, + XlaBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + // Same as CreateConstantFromArray, but for scalars. template ComputationDataHandle CreateConstantFromScalar(NativeT value, @@ -315,6 +358,12 @@ class ClientLibraryTestBase : public ::testing::Test { builder); } + template + XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // @@ -323,10 +372,12 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template - std::unique_ptr CreateR0Parameter( - NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle); + template + std::unique_ptr CreateR0Parameter(NativeT value, + int64 parameter_number, + const string& name, + BuilderT* builder, + HandleT* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -336,11 +387,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -351,11 +401,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -366,11 +415,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle); + const string& name, BuilderT* builder, HandleT* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -381,6 +429,7 @@ class ClientLibraryTestBase : public ::testing::Test { PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } Client* client_; + Client* ref_client_; // To compute reference result. ExecutionOptions execution_options_; private: @@ -399,13 +448,32 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); + tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output); + tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& expected, + tensorflow::gtl::ArraySlice arguments, + const std::function& verify_output, + const Shape* output_with_layout = nullptr); + // Executes the computation and calculates the expected reference value using - // the HloEvaluator. Returns two literal in the order of (expected, actual). + // the HloEvaluator. Returns two literals in the order of (expected, actual). StatusOr, std::unique_ptr>> ComputeValueAndReference(ComputationBuilder* builder, const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice arguments); + // Executes the computation and calculates the expected reference value using + // the reference client. Returns two literals in the order of (expected, + // actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(XlaBuilder* builder, + tensorflow::gtl::ArraySlice arguments); + // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; @@ -414,9 +482,9 @@ class ClientLibraryTestBase : public ::testing::Test { std::vector> arguments_; }; -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -424,9 +492,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - ComputationBuilder* builder, NativeT expected, + BuilderT* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -440,9 +508,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -450,9 +518,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, + BuilderT* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -466,9 +534,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -476,9 +544,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - ComputationBuilder* builder, const Array2D& expected, + BuilderT* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -492,9 +560,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -502,9 +570,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - ComputationBuilder* builder, const Array3D& expected, + BuilderT* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -518,9 +586,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -528,9 +596,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - ComputationBuilder* builder, const Array4D& expected, + BuilderT* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -544,10 +612,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } -template +template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - ComputationBuilder* builder, ComputationDataHandle* data_handle) { + BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -558,11 +626,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -573,11 +640,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -588,11 +654,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, ComputationBuilder* builder, - ComputationDataHandle* data_handle) { + const string& name, BuilderT* builder, HandleT* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -628,6 +693,37 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } +template +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + BuilderT* builder, + HandleT* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +template +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, BuilderT* builder, + HandleT* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 045148cdd11da94ae4789a753efca95c6aaa1f27..32e2f2c0848407ec46a5ac52e2668ef27b92c426 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -109,14 +111,14 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { XLA_TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { - Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr const_arg, client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Add(b.Parameter(0, shape, "param_0"), b.ConstantR2({{1, 2}, {3, 4}})); TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); @@ -124,14 +126,14 @@ XLA_TEST_F(ClientTest, // We can't really test parallel execution on CPU since all of the cores in a // CPU are presented as a single device. So for now we test "parallel" // execution on a single device. - std::vector computation_instances; + std::vector computation_instances; TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); ASSERT_EQ(devices.size(), 1); ExecutionOptions options = execution_options_; *options.add_device_handles() = devices[0]; - computation_instances.push_back(Client::ComputationInstance( + computation_instances.push_back(Client::XlaComputationInstance( add_with_one_arg, {const_arg.get()}, options, nullptr)); TF_ASSERT_OK_AND_ASSIGN(auto results, diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index ec2c580670cfac14ba42e8c9a836c86551af4b89..c15d808f1ddfb44a512fa395bb8e515bca3859b6 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -31,6 +33,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -70,28 +74,35 @@ class ComputeConstantTest : public ::testing::Test { } StatusOr> ComputeConstantLiteral( - Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, Layout* output_layout = nullptr, - tensorflow::gtl::ArraySlice parameters = {}) { - TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant( - operand, output_layout, parameters)); + Client* client, const XlaOp& operand, XlaBuilder* builder, + Layout* output_layout = nullptr) { + TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); + TF_ASSIGN_OR_RETURN(auto computed, + client->ComputeConstant(subgraph, output_layout)); return std::move(computed); } + template + StatusOr ComputeConstantScalar(Client* client, const XlaOp& operand, + XlaBuilder* builder) { + TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, + builder, nullptr)); + return literal->Get({}); + } + template StatusOr ComputeConstantScalar( Client* client, const ComputationDataHandle& operand, ComputationBuilder* builder, tensorflow::gtl::ArraySlice parameters = {}) { - TF_ASSIGN_OR_RETURN( - auto literal, - ComputeConstantLiteral(client, operand, builder, nullptr, parameters)); + TF_ASSIGN_OR_RETURN(auto literal, + builder->ComputeConstant( + operand, /*output_layout=*/nullptr, parameters)); return literal->Get({}); } - bool IsConstant(const ComputationDataHandle& operand, - ComputationBuilder* builder, int64 num_parameters = 0) { - StatusOr result = builder->IsConstant(operand, num_parameters); + bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { + StatusOr result = builder->IsConstant(operand); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.ValueOrDie() : false; } @@ -102,7 +113,7 @@ class ComputeConstantTest : public ::testing::Test { TEST_F(ComputeConstantTest, ScalarInt32Literal) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.ConstantR0(42); EXPECT_TRUE(IsConstant(computation, &b)); @@ -115,7 +126,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) { TEST_F(ComputeConstantTest, ScalarFloatAdd) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Add(b.ConstantR0(42.5f), b.ConstantR0(1.5f)); EXPECT_TRUE(IsConstant(computation, &b)); @@ -129,7 +140,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) { TEST_F(ComputeConstantTest, ScalarRng) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.RngUniform(b.ConstantR0(1.1f), b.ConstantR0(2.1f), ShapeUtil::MakeShape(F32, {})); @@ -150,25 +161,27 @@ TEST_F(ComputeConstantTest, Param) { std::vector arguments; arguments.push_back(std::move(*Literal::CreateR0(42.5f))); - EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); - - auto value = - ComputeConstantScalar(client, computation, &b, arguments); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 44.0f); + TF_ASSERT_OK_AND_ASSIGN(bool is_constant, + b.IsConstant(computation, arguments.size())); + EXPECT_TRUE(is_constant); + + TF_ASSERT_OK_AND_ASSIGN( + auto value, + ComputeConstantScalar(client, computation, &b, arguments)); + EXPECT_EQ(value, 44.0f); } } TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on a parameter")) + EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), + "depends on a parameter")) << value.status(); } } @@ -176,15 +189,15 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { TEST_F(ComputeConstantTest, IndirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Add(b.ConstantR0(1.0f), b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on a parameter")) + EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), + "depends on a parameter")) << value.status(); } } @@ -194,7 +207,7 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { TEST_F(ComputeConstantTest, UnrelatedParam) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0"); auto constant_4 = @@ -211,64 +224,64 @@ TEST_F(ComputeConstantTest, UnrelatedParam) { EXPECT_TRUE(IsConstant(constant_13, &b)); - auto value = ComputeConstantScalar(client, constant_13, &b); - ASSERT_TRUE(value.ok()) << value.status(); - EXPECT_EQ(value.ValueOrDie(), 13.0f); + TF_ASSERT_OK_AND_ASSIGN( + auto value, ComputeConstantScalar(client, constant_13, &b)); + EXPECT_EQ(value, 13.0f); } } TEST_F(ComputeConstantTest, NonScalarAdd) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Add(b.ConstantR1({1, 2}), b.ConstantR1({3, 4})); EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(client, computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computed, + ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); } } TEST_F(ComputeConstantTest, IntegerDivide) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); auto computation = b.Div(b.ConstantR0(15), b.ConstantR0(3)); EXPECT_TRUE(IsConstant(computation, &b)); - auto computed = ComputeConstantLiteral(client, computation, &b); - ASSERT_TRUE(computed.ok()) << computed.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computed, + ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); } } XLA_TEST_F(ComputeConstantTest, Layout) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); + XlaBuilder b(TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& layout : layouts) { auto layout_proto = LayoutUtil::MakeLayout(layout); - auto computed = ComputeConstantLiteral( - client, - b.Add(b.ConstantR2({{1, 2}, {3, 4}}), - b.ConstantR2({{10, 20}, {30, 40}})), - &b, &layout_proto); - ASSERT_TRUE(computed.ok()) << computed.status(); + TF_ASSERT_OK_AND_ASSIGN( + auto computed, ComputeConstantLiteral( + client, + b.Add(b.ConstantR2({{1, 2}, {3, 4}}), + b.ConstantR2({{10, 20}, {30, 40}})), + &b, &layout_proto)); std::unique_ptr expected_literal = Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), computed.ValueOrDie()->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), + computed->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index 1bcad5a3f37a37c9d482f3a5a899ac527666cca3..a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -38,9 +38,9 @@ using ::testing::HasSubstr; // Concatenate expects at least one argument. XLA_TEST_F(ConcatTest, Concat_Nothing) { - ComputationBuilder builder(client_, TestName()); - auto concatenated = builder.ConcatInDim({}, 0); - StatusOr computation_status = builder.Build(); + XlaBuilder builder(TestName()); + builder.ConcatInDim({}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), HasSubstr("Concatenate expects at least one argument")); @@ -48,18 +48,18 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) { // Concatenate with one argument works. XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a}, 0); + builder.ConcatInDim({a}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -68,51 +68,51 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) { // Show that we can't concatenate R0 with R0 because we can't name the dimension // to concatenate on. XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(42.0); auto b = builder.ConstantR0(64.0); - auto concatenated = builder.ConcatInDim({a, b}, 0); - StatusOr computation_status = builder.Build(); + builder.ConcatInDim({a, b}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), - HasSubstr("dimension to concatenate along out of bounds: 0")); + HasSubstr("out of bounds: 0")); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); auto b = builder.ConstantR1({}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0, 64.0}); auto b = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -129,20 +129,20 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) { expected[253 + i] = rhs[i] = 253 + i + 1; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1(lhs); auto b = builder.ConstantR1(rhs); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { for (int dim : {0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(Array2D(0, 0)); auto b = builder.ConstantR2FromArray2D(Array2D(0, 0)); - auto concatenated = builder.ConcatInDim({a, b}, dim); + builder.ConcatInDim({a, b}, dim); ComputeAndCompareR2(&builder, Array2D(0, 0), {}, ErrorSpec(0.0001)); @@ -150,26 +150,27 @@ XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) { } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D expected({ - {0}, {64}, + {0}, + {64}, }); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(1, 1); auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D expected({ {0, 64}, @@ -178,22 +179,22 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) { } XLA_TEST_F(ConcatTest, Concat2x0With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); ComputeAndCompareR2(&builder, *b_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat2x3With2x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(2, 3); auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 1); + builder.ConcatInDim({a, b}, 1); Array2D expected({ {0, 1, 2, 64, 65, 66, 67, 68}, @@ -203,22 +204,22 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) { } XLA_TEST_F(ConcatTest, Concat3x2With0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); ComputeAndCompareR2(&builder, *a_array, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat3x2With5x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a_array = CreatePatternedMatrix(3, 2); auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0); auto a = builder.ConstantR2FromArray2D(*a_array); auto b = builder.ConstantR2FromArray2D(*b_array); - auto concatenated = builder.ConcatInDim({a, b}, 0); + builder.ConcatInDim({a, b}, 0); Array2D expected({ {0, 1}, @@ -234,16 +235,16 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) { } XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR3FromArray3D(Array3D(3, 0, 2)); auto b = builder.ConstantR3FromArray3D(Array3D(3, 0, 1)); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); ComputeAndCompareR3(&builder, Array3D(3, 0, 3), {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_array({ // 3x1x2 {{0, 1}}, @@ -258,27 +259,29 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) { }); auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); - auto concatenated = builder.ConcatInDim({a, b}, 2); + builder.ConcatInDim({a, b}, 2); Array3D expected({ - {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}}, + {{0, 1, 6}}, + {{2, 3, 7}}, + {{4, 5, 8}}, }); ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); - auto concatenated = builder.ConcatInDim({a, b, c}, 0); + builder.ConcatInDim({a, b, c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D a_array({ // 3x1x2 {{0, 1}}, @@ -300,35 +303,35 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) { auto a = builder.ConstantR3FromArray3D(a_array); auto b = builder.ConstantR3FromArray3D(b_array); auto c = builder.ConstantR3FromArray3D(c_array); - auto concatenated = builder.ConcatInDim({a, b, c}, 2); + builder.ConcatInDim({a, b, c}, 2); Array3D expected({ - {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}}, + {{0, 1, 2, 3}}, + {{4, 5, 6, 7}}, + {{8, 9, 10, 11}}, }); ComputeAndCompareR3(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); // concatenated = (a concat b) concat c - auto concatenated = - builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); + builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0}); auto b = builder.ConstantR1({64.0}); auto c = builder.ConstantR1({256.0}); // concatenated = a concat (b concat c) - auto concatenated = - builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); + builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0); std::vector expected = {42, 64, 256}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -342,7 +345,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 0); @@ -363,7 +366,7 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) { rhs(0, i) = i + 1024; } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 1); @@ -388,7 +391,7 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2D(lhs); auto b = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a, b}, 1); @@ -404,13 +407,13 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) { // Show that we can't concatenate with an opaques. XLA_TEST_F(ConcatTest, CannotConcatOpaques) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto opaque_shape = ShapeUtil::MakeOpaqueShape(); auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); auto x = builder.Parameter(0, r1f32, "x"); auto y = builder.Parameter(1, opaque_shape, "y"); - auto concatenated = builder.ConcatInDim({x, y}, 0); - StatusOr computation_status = builder.Build(); + builder.ConcatInDim({x, y}, 0); + StatusOr computation_status = builder.Build(); ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), @@ -418,23 +421,23 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto p0 = builder.ConstantR1({true}); auto p1 = builder.ConstantR1({false}); auto p2 = builder.ConstantR1({true}); - auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0); + builder.ConcatInDim({p0, p1, p2}, 0); bool expected[] = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR1({1}); auto a1 = builder.ConstantR1({2, 3}); auto a2 = builder.ConstantR1({4, 5, 6}); auto a3 = builder.ConstantR1({7, 8, 9, 10}); - auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0); + builder.ConcatInDim({a0, a1, a2, a3}, 0); std::vector expected(10); std::iota(expected.begin(), expected.end(), 1); @@ -442,7 +445,7 @@ XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) { } XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D arr0(9, 17, 1); arr0.Fill(1); @@ -462,14 +465,14 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) { } } - ComputationDataHandle h0; + XlaOp h0; auto p0 = CreateR3Parameter(arr0, /*parameter_number=*/0, "p0", &builder, &h0); - ComputationDataHandle h1; + XlaOp h1; auto p1 = CreateR3Parameter(arr1, /*parameter_number=*/1, "p1", &builder, &h1); - auto concatenated = builder.ConcatInDim({h0, h1}, 2); + builder.ConcatInDim({h0, h1}, 2); ComputeAndCompareR3(&builder, expected, {p0.get(), p1.get()}); } @@ -495,7 +498,7 @@ TEST_P(ConcatR2BinaryTest, DoIt) { Array2D rhs(spec.rhs_dim0, spec.rhs_dim1); rhs.FillUnique(1000); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a0 = builder.ConstantR2FromArray2D(lhs); auto a1 = builder.ConstantR2FromArray2D(rhs); builder.ConcatInDim({a0, a1}, spec.concat_dimension); @@ -521,7 +524,7 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, f32_scalar, "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto mul = builder.Mul(x, y); @@ -545,7 +548,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "z"); @@ -573,7 +576,7 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, x_literal->shape(), "x"); auto y = builder.Parameter(1, f32_scalar, "y"); auto z = builder.Parameter(2, f32_scalar, "y"); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index bc821674820fb128823786d7149037fc59b22ab6..7ff6706935740c7d76ee5cd03eae292386760397 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -23,8 +24,8 @@ namespace { class ConditionalOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0ConstantComputation(float value) { - ComputationBuilder builder(client_, "Constant"); + XlaComputation CreateR0ConstantComputation(float value) { + XlaBuilder builder("Constant"); builder.Parameter(0, empty_tuple_, "tuple"); builder.ConstantR0(value); auto build_status = builder.Build(); @@ -32,16 +33,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0IdentityComputation() { - ComputationBuilder builder(client_, "Identity"); + XlaComputation CreateR0IdentityComputation() { + XlaBuilder builder("Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateCeilComputation(const Shape& shape) { - ComputationBuilder builder(client_, "Ceil"); + XlaComputation CreateCeilComputation(const Shape& shape) { + XlaBuilder builder("Ceil"); auto param = builder.Parameter(0, shape, "param"); builder.Ceil(param); auto build_status = builder.Build(); @@ -49,16 +50,16 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0CeilComputation() { + XlaComputation CreateR0CeilComputation() { return CreateCeilComputation(r0f32_); } - Computation CreateR1CeilComputation() { + XlaComputation CreateR1CeilComputation() { return CreateCeilComputation(r1s2f32_); } - Computation CreateFloorComputation(const Shape& shape) { - ComputationBuilder builder(client_, "Floor"); + XlaComputation CreateFloorComputation(const Shape& shape) { + XlaBuilder builder("Floor"); auto param = builder.Parameter(0, shape, "param"); builder.Floor(param); auto build_status = builder.Build(); @@ -66,17 +67,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0FloorComputation() { + XlaComputation CreateR0FloorComputation() { return CreateFloorComputation(r0f32_); } - Computation CreateR1FloorComputation() { + XlaComputation CreateR1FloorComputation() { return CreateFloorComputation(r1s2f32_); } - Computation CreateTupleCeilComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleCeilComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -88,17 +89,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleCeilComputation() { + XlaComputation CreateR0TupleCeilComputation() { return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_); } - Computation CreateR1TupleCeilComputation() { + XlaComputation CreateR1TupleCeilComputation() { return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_); } - Computation CreateTupleFloorComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleFloorComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -110,17 +111,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleFloorComputation() { + XlaComputation CreateR0TupleFloorComputation() { return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_); } - Computation CreateR1TupleFloorComputation() { + XlaComputation CreateR1TupleFloorComputation() { return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_); } - Computation CreateTupleAddComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleAddComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -130,17 +131,17 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleAddComputation() { + XlaComputation CreateR0TupleAddComputation() { return CreateTupleAddComputation("AddR0", tuple_2_r0f32_); } - Computation CreateR1TupleAddComputation() { + XlaComputation CreateR1TupleAddComputation() { return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_); } - Computation CreateTupleSubComputation(const string& computation_name, - const Shape& tuple_shape) { - ComputationBuilder builder(client_, computation_name); + XlaComputation CreateTupleSubComputation(const string& computation_name, + const Shape& tuple_shape) { + XlaBuilder builder(computation_name); auto tuple = builder.Parameter(0, tuple_shape, "tuple"); auto x = builder.GetTupleElement(tuple, 0); auto y = builder.GetTupleElement(tuple, 1); @@ -150,11 +151,11 @@ class ConditionalOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0TupleSubComputation() { + XlaComputation CreateR0TupleSubComputation() { return CreateTupleSubComputation("SubR0", tuple_2_r0f32_); } - Computation CreateR1TupleSubComputation() { + XlaComputation CreateR1TupleSubComputation() { return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_); } @@ -170,26 +171,25 @@ class ConditionalOpTest : public ClientLibraryTestBase { // Test true and false computations that do not take any parameters. XLA_TEST_F(ConditionalOpTest, Parameters0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({}); auto true_computation = CreateR0ConstantComputation(56.0f); auto false_computation = CreateR0ConstantComputation(12.0f); - auto result = builder.Conditional(pred, operands, true_computation, operands, - false_computation); + builder.Conditional(pred, operands, true_computation, operands, + false_computation); ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); } // Test true and false computations that take in 1 parameter. XLA_TEST_F(ConditionalOpTest, Parameters1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); auto identity = CreateR0IdentityComputation(); - auto result = - builder.Conditional(pred, operand1, identity, operand2, identity); + builder.Conditional(pred, operand1, identity, operand2, identity); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -197,12 +197,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) { // Test conditional with two different computations in the true and false cases // that take in different arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); - auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -210,11 +210,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { // Test conditional with two different computations in the true and false cases // that take in the same arguments. XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand = builder.ConstantR0(12.6f); - auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(), - operand, CreateR0FloorComputation()); + builder.Conditional(pred, operand, CreateR0CeilComputation(), operand, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -222,12 +222,12 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { // Test conditional with the same computation in the true and false cases but // take in different arguments. XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); auto floor = CreateR0FloorComputation(); - auto result = builder.Conditional(pred, operand1, floor, operand2, floor); + builder.Conditional(pred, operand1, floor, operand2, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -235,11 +235,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { // Test conditional with the same computation in the true and false cases that // take in the same arguments. XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand = builder.ConstantR0(12.6f); auto floor = CreateR0FloorComputation(); - auto result = builder.Conditional(pred, operand, floor, operand, floor); + builder.Conditional(pred, operand, floor, operand, floor); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -247,12 +247,12 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { // Test conditional with different instances of the same computation in the true // and false cases. XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); - auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); } @@ -260,7 +260,7 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { // Test the case when a call invokes a computation that contains a conditional. XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); @@ -268,7 +268,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { false_operand, CreateR0FloorComputation()); auto inner_builder_result = inner_builder.Build(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.4f); auto operand2 = builder.ConstantR0(12.6f); @@ -281,14 +281,13 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { // Test true and false computations that take in 2 parameters and predicate is // true. XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), - operands, CreateR0TupleSubComputation()); + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); } @@ -296,14 +295,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { // Test true and false computations that take in 2 parameters and predicate is // false. XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR0TupleAddComputation(), - operands, CreateR0TupleSubComputation()); + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands, + CreateR0TupleSubComputation()); ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); } @@ -311,14 +309,13 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { // Test true and false computations that take in 2 array parameters and // predicate is true. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operand1 = builder.ConstantR1({24.0f, 56.0f}); auto operand2 = builder.ConstantR1({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), - operands, CreateR1TupleSubComputation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); } @@ -326,21 +323,20 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { // Test true and false computations that take in 2 array parameters and // predicate is false. XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operand1 = builder.ConstantR1({24.0f, 56.0f}); auto operand2 = builder.ConstantR1({10.0f, 11.0f}); auto operands = builder.Tuple({operand1, operand2}); - auto result = - builder.Conditional(pred, operands, CreateR1TupleAddComputation(), - operands, CreateR1TupleSubComputation()); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR1TupleSubComputation()); ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); } // Test true and false computations that return a tuple of scalars. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operands = builder.Tuple( {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); @@ -356,7 +352,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { // Test true and false computations that return a tuple of arrays. XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), builder.ConstantR1({25.6f, 29.2f})}); @@ -373,7 +369,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { // Test true and false computations that return a tuple of a predicate, a // scalar, and an array. XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { - ComputationBuilder true_builder(client_, TestName() + ".true"); + XlaBuilder true_builder(TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); auto true_pred = true_builder.ConstantR0(true); @@ -384,7 +380,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); - ComputationBuilder false_builder(client_, TestName() + ".false"); + XlaBuilder false_builder(TestName() + ".false"); { false_builder.Parameter(0, empty_tuple_, "tuple"); auto false_pred = false_builder.ConstantR0(false); @@ -395,7 +391,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operands = builder.Tuple({}); builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), @@ -411,7 +407,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { // Test true and false computations that return a nested tuple. XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { - ComputationBuilder true_builder(client_, TestName() + ".true"); + XlaBuilder true_builder(TestName() + ".true"); { true_builder.Parameter(0, empty_tuple_, "tuple"); auto true_constant1 = true_builder.ConstantR0(12.2f); @@ -424,7 +420,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { auto true_builder_result = true_builder.Build(); EXPECT_IS_OK(true_builder_result.status()); - ComputationBuilder false_builder(client_, TestName() + ".false"); + XlaBuilder false_builder(TestName() + ".false"); { false_builder.Parameter(0, empty_tuple_, "tuple"); auto false_constant1 = false_builder.ConstantR0(46.6f); @@ -438,7 +434,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { auto false_builder_result = false_builder.Build(); EXPECT_IS_OK(false_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto operands = builder.Tuple({}); builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), @@ -460,16 +456,16 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { // params. XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle pred, operand1, operand2; + XlaOp pred, operand1, operand2; auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); auto operand1_param = CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); - auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), - operand2, CreateR0FloorComputation()); + builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2, + CreateR0FloorComputation()); ComputeAndCompareR0( &builder, 57.0f, @@ -480,16 +476,16 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { // Test conditional that takes in array operands in the form of external params. XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle pred, operand1, operand2; + XlaOp pred, operand1, operand2; auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); auto operand1_param = CreateR1Parameter({24.3f, 56.7f}, 1, "operand1", &builder, &operand1); auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", &builder, &operand2); - auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(), - operand2, CreateR1FloorComputation()); + builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2, + CreateR1FloorComputation()); ComputeAndCompareR1( &builder, {10.0f, 11.0f}, @@ -499,7 +495,7 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { // Test the case where one conditional is nested within another. XLA_TEST_F(ConditionalOpTest, NestedConditionals) { - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); @@ -514,7 +510,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred1 = builder.ConstantR0(true); auto pred2 = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(1.1f); @@ -529,7 +525,7 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) { } XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { - ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + XlaBuilder inner_builder(TestName() + ".inner_conditional"); { Shape r0bool = ShapeUtil::MakeShape(PRED, {}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); @@ -544,7 +540,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { auto inner_builder_result = inner_builder.Build(); EXPECT_IS_OK(inner_builder_result.status()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred2 = builder.ConstantR0(false); auto operand1 = builder.ConstantR0(1.1f); auto operand2 = builder.ConstantR0(12.2f); @@ -556,7 +552,7 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { // Test a mismatch in the shape of the true operand and true computation. XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto operand1 = builder.ConstantR0(56.0f); auto operand2 = builder.ConstantR0(12.0f); @@ -571,5 +567,56 @@ XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { "only parameter of true_computation")); } +XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_}); + XlaComputation swapper; + { + XlaBuilder builder(TestName() + ".swapper"); + auto param0 = builder.Parameter(0, tuple_shape, "sp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({y, x}); + swapper = builder.Build().ConsumeValueOrDie(); + } + XlaComputation forwarder; + { + XlaBuilder builder(TestName() + ".forwarder"); + auto param0 = builder.Parameter(0, tuple_shape, "fp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + builder.Tuple({x, y}); + forwarder = builder.Build().ConsumeValueOrDie(); + } + XlaComputation main; + { + XlaBuilder builder(TestName() + ".main"); + auto param0 = builder.Parameter(0, tuple_shape, "mp0"); + auto x = builder.GetTupleElement(param0, 0); + auto y = builder.GetTupleElement(param0, 1); + auto lt_pred = builder.Lt(x, y); + auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper); + auto ge_pred = builder.Ge(x, y); + builder.Conditional(ge_pred, res, swapper, res, forwarder); + main = builder.Build().ConsumeValueOrDie(); + } + + auto test_swap = [&](float a, float b) { + XlaBuilder builder(TestName()); + auto x = builder.ConstantR0(a); + auto y = builder.ConstantR0(b); + auto tuple_operand = builder.Tuple({x, y}); + builder.Call(main, {tuple_operand}); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(a).get(), + Literal::CreateR0(b).get()}), + {}, error_spec_); + }; + + test_swap(3.11f, 9.4f); + test_swap(11.24f, 5.55f); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index f66e3b57bf45fbc9f8ea786146d6fffe5d55a262..0842a8918bcfec037ab0f9aa24014c7d8296cdf8 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -106,11 +108,163 @@ TEST_F(ConvertTest, ConvertR1F32ToR1S32) { XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, F32); + std::vector arg{ + -9223371216516022272, + -2, + -1, + -0x7FFFFFFF, + -0x80000000, + 0, + 1, + 2, + 1073742145, + 1073742656, + 0x7FFFFFFF, + 0x80000000, + 826720496944058148, + 4296062029846194332, + 0x0007FB72E4000000LL, + 0x0007FB72E4000001LL, + 0x0007FB72E6000000LL, + 0x0007FB72E7000000LL, + 0x0007FB72E7FFFFFFLL, + 0x0007FB72E8000000LL, + 0x0007FB72E8000001LL, + 0x0007FB72EA000000LL, + 0x0007FB72EB000000LL, + 0x0007FB72EBFFFFFFLL, + 0x0007FB72EC000000LL, + 0x7FFFFF0000000000LL, + 0x7FFFFF8000000000LL, + 0x7FFFFFFFFFFFFF00, + static_cast(0xFFFFFFFFFFFFFFFF), + static_cast(0x0000f234e67e0001LL), + static_cast(0x8000000000000000), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + static_cast(0x8000008000000000LL), + static_cast(0x8000010000000000LL), + }; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, F32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} - std::vector expected = {32.0, 64.0}; - ComputeAndCompareR1(&builder, expected, {}); +XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0, 1, 0x1000, 0x7fffffff, + 0x80000000, 0x80000001, 0x80000002, 0x80000003, + 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, F32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0.0f, 1.0f, 16777216.0f, + 16777218.0f, 2147483647.0f, 4294967040.0f}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, U32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0, 1, 0x1000, -1, -0x1000}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + // Test cases from compiler_rt library. + std::vector arg{0.0f, + 0.5f, + 0.99f, + 1.0f, + 1.5f, + 1.99f, + 2.0f, + 2.01f, + 2147483648.f, + -0.5f, + -0.99f, + -1.0f, + -1.5f, + -1.99f, + -2.0f, + -2.01f, + 0x1.FFFFFEp+62F, + 0x1.FFFFFCp+62F, + -0x1.FFFFFEp+62F, + -0x1.FFFFFCp+62F}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { @@ -208,5 +362,104 @@ TEST_F(ConvertTest, ConvertReshape) { ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } +std::vector GetInterestingF16ConversionTestCases() { + float infinity = std::numeric_limits::infinity(); + float half_min_positive_normal = + tensorflow::bit_cast(0x38800000); + float half_max_subnormal = tensorflow::bit_cast(0x387fc000); + float half_min_positive_subnormal = + tensorflow::bit_cast(0x33800000); + float half_max = 65504.0f; + + std::vector test_cases( + {-infinity, -(half_max * 2 + 1), -half_max, -42.0f, -1.0f, + -half_min_positive_subnormal, -half_max_subnormal, + -half_min_positive_normal, -0.0f, 0.0f, half_min_positive_subnormal, + half_max_subnormal, half_min_positive_normal, 1.0f, 42.0f, half_max, + (half_max * 2 + 1), infinity}); + return test_cases; +} + +XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { + std::vector test_cases = GetInterestingF16ConversionTestCases(); + std::vector input; + c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); + std::vector expected_output; + c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast(h); }); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dot_lhs_handle, + client_->TransferToServer(*Literal::CreateR1(input))); + + ComputationBuilder builder(client_, TestName()); + builder.ConvertElementType( + builder.Parameter( + 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), + "param"), + F32); + + ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { + std::vector input = GetInterestingF16ConversionTestCases(); + std::vector expected_output; + c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dot_lhs_handle, + client_->TransferToServer(*Literal::CreateR1(input))); + + ComputationBuilder builder(client_, TestName()); + builder.ConvertElementType( + builder.Parameter( + 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), + "param"), + F16); + + ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertC64ToC64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{42.0f, 64.0f}}; + builder.ConvertElementType(builder.ConstantR1(x), C64); + ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(ConvertTest, ConvertS64S64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{-42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), S64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64U64) { + ComputationBuilder builder(client_, TestName()); + std::vector x = {{42, 64}}; + builder.ConvertElementType(builder.ConstantR1(x), U64); + ComputeAndCompareR1(&builder, x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertU64S64) { + ComputationBuilder builder(client_, TestName()); + std::vector unsigned_x = {{42, UINT64_MAX}}; + builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); + std::vector signed_x = {{42, -1}}; + ComputeAndCompareR1(&builder, signed_x, {}); +} + +XLA_TEST_F(ConvertTest, ConvertS64U64) { + ComputationBuilder builder(client_, TestName()); + std::vector signed_x = {{42, -1, INT64_MIN}}; + builder.ConvertElementType(builder.ConstantR1(signed_x), U64); + std::vector unsigned_x = { + {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; + ComputeAndCompareR1(&builder, unsigned_x, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 1385b437fc47fe5289c401581fab8b5278872382..947959beb144e1509a77ad2f94b8493de46ba6f2 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -53,27 +53,12 @@ class ConvolutionTest : public ClientLibraryTestBase { #endif }; -// TODO(b/72509305): Enable half data type tests for CPU -#if (XLA_TEST_BACKEND_GPU) -using TestTypes = ::testing::Types; -#else +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; #endif -template -Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions); - -template <> -Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions) { - return ShapeUtil::MakeShape(F32, dimensions); -} - -template <> -Shape MakeShapeWrapper( - tensorflow::gtl::ArraySlice dimensions) { - return ShapeUtil::MakeShape(F16, dimensions); -} - template class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { public: @@ -103,12 +88,12 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { ASSERT_EQ(2, arhs->width()); ASSERT_EQ(2, arhs->height()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR4FromArray4D(*alhs); auto rhs = builder.ConstantR4FromArray4D(*arhs); - auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - ComputeAndCompare(&builder, conv, {}, error_spec_); + ComputeAndCompare(&builder, {}, error_spec_); } }; @@ -121,12 +106,12 @@ template class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 1, 2}); - Shape filter_shape = MakeShapeWrapper({1, 1, 1, 2}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -137,7 +122,7 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { {5.0f, 6.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -152,12 +137,12 @@ template class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -171,7 +156,7 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {5.0f, 6.0f}, {7.0f, 8.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -186,12 +171,12 @@ template class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + builder.Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({ @@ -206,7 +191,7 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { {7.0f, 8.0f}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -222,12 +207,12 @@ template class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); - Shape filter_shape = MakeShapeWrapper({1, 1, 3, 3}); + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 1, 3, 3}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + builder.Conv(input, filter, {1, 1}, Padding::kSame); Array4D input_data(1, 1, 4, 4); input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, @@ -238,7 +223,7 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { filter_data.FillWithYX(Array2D( {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); @@ -249,7 +234,7 @@ TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes); TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -279,10 +264,10 @@ template class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { - Shape input_shape = MakeShapeWrapper({1, 2, 5}); - Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. @@ -315,7 +300,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -346,7 +331,7 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); @@ -380,10 +365,10 @@ template class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { - Shape input_shape = MakeShapeWrapper({1, 2, 5}); - Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); // Convolution dimensions are bf0_oi0->bo0. @@ -417,7 +402,7 @@ TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_dims = {1, 4, 2, 3, 3}; std::vector filter_dims = {2, 2, 2, 3, 3}; Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); @@ -484,11 +469,11 @@ template class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { public: void RunTest() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_dims = {1, 3, 3, 5}; std::vector filter_dims = {3, 3, 5, 3}; - Shape input_shape = MakeShapeWrapper(input_dims); - Shape filter_shape = MakeShapeWrapper(filter_dims); + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); @@ -552,7 +537,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "convolution-canonicalization"); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); @@ -566,8 +551,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, dnums.set_kernel_output_feature_dimension(1); dnums.set_output_batch_dimension(0); dnums.set_output_feature_dimension(1); - auto conv = builder.ConvWithGeneralDimensions(input, filter, {}, - Padding::kValid, dnums); + builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums); Array2D param0(4, 29); param0.FillUnique(); @@ -578,7 +562,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Array2D expected_result(29, 10); expected_result.Fill(0); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(param0)), std::move(*Literal::CreateFromArray(param1))}, error_spec_); @@ -602,7 +586,7 @@ class Convolve1D1WindowTestBase protected: template void TestImpl() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); int64 input_feature = GetParam().input_feature; int64 output_feature = GetParam().output_feature; int64 batch = GetParam().batch; @@ -612,8 +596,8 @@ class Convolve1D1WindowTestBase input_feature}; std::vector filter_dims = {window_size, input_feature, output_feature}; - Shape input_shape = MakeShapeWrapper(input_dims); - Shape filter_shape = MakeShapeWrapper(filter_dims); + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); @@ -699,9 +683,7 @@ INSTANTIATE_TEST_CASE_P( #if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {}; -// TODO(b/72509305): Enable half data type tests for CPU. -XLA_TEST_P(Convolve1D1WindowTestHalf, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(Convolve1D1Window))) { +XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) { TestImpl(); } @@ -719,14 +701,16 @@ INSTANTIATE_TEST_CASE_P( Convolve1DTestParam{130, 1, 1, 1, 3}, Convolve1DTestParam{64, 1, 1, 1, 1}, Convolve1DTestParam{128, 1, 1, 1, 1}, - // TODO(b/72566306): the following three tests fail on CPU - // backend due to result miscompare. +// TODO(b/72566306): The following five tests failed on CPU with unreasonable +// relative errors. Last ran on 2018-02-22. +#if XLA_TEST_BACKEND_GPU Convolve1DTestParam{139, 1, 1, 128, 1}, Convolve1DTestParam{640, 3, 3, 128, 1}, Convolve1DTestParam{900, 1, 1, 10, 1}, Convolve1DTestParam{1, 10, 10, 1, 10}, - Convolve1DTestParam{1, 10, 130, 1, 2}, Convolve1DTestParam{1, 10, 130, 1, 1}, +#endif + Convolve1DTestParam{1, 10, 130, 1, 2}, Convolve1DTestParam{1, 64, 64, 1, 10}, Convolve1DTestParam{1, 65, 65, 1, 1}, Convolve1DTestParam{1, 128, 128, 1, 1}, @@ -738,13 +722,13 @@ INSTANTIATE_TEST_CASE_P( ); #endif -TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { - ComputationBuilder builder(client_, TestName()); +XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { + XlaBuilder builder(TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + builder.Conv(input, filter, {1, 1}, Padding::kValid); Array4D input_data(1, 1, 1, 2); input_data.FillWithYX(Array2D({ @@ -755,11 +739,34 @@ TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { {bfloat16(5), bfloat16(6)}, })); - ComputeAndCompare(&builder, conv, + ComputeAndCompare(&builder, {std::move(*Literal::CreateFromArray(input_data)), std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } +// Check that GPU convs still work if the CudnnAlgorithmPicker pass is disabled. +// (We run this test on all platforms, because, what the heck.) +XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "cudnn-convolution-algorithm-picker"); + + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillIota(0); + Array4D filter_data(1, 1, 1, 2); + filter_data.FillIota(10); + + ComputeAndCompare(&builder, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 2d847a66b0ae7c8f09fa0cb181a4c84ea99be5b1..b43d5c9ff5d75ee0e1b3c9ceb2bc295e631ac107 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -134,9 +134,9 @@ class CustomCallClientAPITest : public ClientLibraryTestBase {}; // When using the client API, CustomCall targets can't begin with '$' -- these // are reserved for internal use. XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { - ComputationBuilder builder(client_, TestName()); - auto call = builder.CustomCall("$illegal", /*operands=*/{}, - ShapeUtil::MakeShape(F32, {1})); + XlaBuilder builder(TestName()); + builder.CustomCall("$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); StatusOr> result = Execute(&builder, /*arguments=*/{}); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 032c06cd3c9f872f57674d3d7b5adc201c91ea77..3ab0ea4ad48c00724d48e7d285ec024e10d5db31 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -195,7 +195,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { auto result_status = client_->DeconstructTuple(*global_data); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("deconstructing nested tuples not yet supported")); + HasSubstr("Deconstructing nested tuples is not implemented")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6b0c04c2c083bbfce267dd92d24ef15c06186d26..c4031dfee593a13af6a5db15e43ed7bc418603c5 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -34,169 +34,220 @@ limitations under the License. namespace xla { namespace { -// TODO(b/34468543): use GUnit typed tests when we can do all tests on all -// backends. class DotOperationTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 1e-5}; - - protected: - template - void TestOneElementVectorDot(); - template - void TestVectorDot(); - template - void TestSquareMatrixDot(bool lhs_row_major = false, - bool rhs_row_major = false); - template - void TestNonsquareMatrixDot(bool lhs_row_major = false, - bool rhs_row_major = false); }; -XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { +#if defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = ::testing::Types; +#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = + ::testing::Types; +#elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX) +using TypesF16F32 = ::testing::Types; +using TypesF16F32F64 = ::testing::Types; +using TypesF16F32F64CF64 = + ::testing::Types; +#else +#error "Situation not handled yet" +#endif + +// Check that we can safely pass an input tuple's elements to a dot operation. +TEST_F(DotOperationTest, DotOfInputTupleElem) { ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({}); - auto rhs = builder.ConstantR1({}); - auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 0.0, {}, error_spec_); + ComputationDataHandle param; + auto param_data = CreateParameterAndTransferLiteral( + 0, + *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), + Literal::CreateR2({{5, 6}, {7, 8}}).get()}), + "arg0", &builder, ¶m); + auto lhs = builder.GetTupleElement(param, 0); + auto rhs = builder.GetTupleElement(param, 1); + builder.Dot(lhs, rhs); + + ComputeAndCompareLiteral(&builder, + *Literal::CreateR2({{19, 22}, {43, 50}}), + {param_data.get()}); } -XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2({{3.0, 4.0}}); - auto rhs = builder.ConstantR1({3.0, 4.0}); - auto result = builder.Dot(lhs, rhs); +template +class DotOperationTest_F16F32F64CF64 : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64); - ComputeAndCompareR1(&builder, {25.0}, {}, error_spec_); -} +XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); -template -void DotOperationTest::TestOneElementVectorDot() { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({2.0}); - auto rhs = builder.ConstantR1({3.0}); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 6.0, {}, error_spec_); + this->template ComputeAndCompareR0(&builder, static_cast(0.0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) { - TestOneElementVectorDot(); -} +template +class DotOperationTest_F16F32F64 : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); -XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) { - TestOneElementVectorDot(); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); + auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); + auto result = builder.Dot(lhs, rhs); + + this->template ComputeAndCompareR1(&builder, {static_cast(25.0f)}, {}, + this->error_spec_); } -template -void DotOperationTest::TestVectorDot() { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR1({1.0, 2.5, 42.0}); - auto rhs = builder.ConstantR1({11.0, -1.0, 0.5}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR1({static_cast(2.0f)}); + auto rhs = builder.ConstantR1({static_cast(3.0f)}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR0(&builder, 29.5, {}, error_spec_); + this->template ComputeAndCompareR0(&builder, static_cast(6.0f), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot(); } - -XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot(); } +XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); + auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); + auto result = builder.Dot(lhs, rhs); -namespace { + this->template ComputeAndCompareR0(&builder, static_cast(29.5f), {}, + this->error_spec_); +} std::vector MinorToMajorForIsRowMajor(bool row_major) { return {row_major ? 1 : 0, row_major ? 0 : 1}; } -} // namespace - -XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(0, 0), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(0, 0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto rhs = builder.ConstantR2({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); + auto rhs = builder.ConstantR2FromArray2D( + {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(0, 3), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(0, 3), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) { - ComputationBuilder builder(client_, TestName()); - auto lhs = - builder.ConstantR2({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}}); - auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D( + {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); + auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, error_spec_); + this->template ComputeAndCompareR2(&builder, Array2D(3, 0), {}, + this->error_spec_); } -XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) { - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); - auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); + auto rhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto result = builder.Dot(lhs, rhs); - ComputeAndCompareR2(&builder, Array2D(2, 2, 0.0f), {}, - error_spec_); + this->template ComputeAndCompareR2( + &builder, Array2D(2, 2, static_cast(0.0f)), {}, this->error_spec_); } -XLA_TEST_F(DotOperationTest, FusedDot) { - ComputationBuilder builder(client_, TestName()); - auto param0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 4}), "arg0"); - auto param1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4, 1}), "arg1"); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto param0 = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); + auto param1 = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({4, 1}), "arg1"); auto exp0 = builder.Exp(param0); auto result = builder.Dot(exp0, param1); - auto lhs_handle = client_ - ->TransferToServer(*Literal::CreateR2( - {{1.0, 2.0, 3.0, 4.0}, {-1.0, -2.0, -3.0, -4.0}})) - .ConsumeValueOrDie(); - auto rhs_handle = client_ - ->TransferToServer(*Literal::CreateR2( - {{1.0}, {2.0}, {3.0}, {4.0}})) - .ConsumeValueOrDie(); - - ComputeAndCompareR2( - &builder, Array2D({{296.14560492846033}, {0.8611737683031964}}), - {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - -template -void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, - bool rhs_row_major) { auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 2.0}, {3.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 6.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2D( + {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); + auto rhs_handle = this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2D( + {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) + .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + if (std::is_same::value) { + this->error_spec_ = ErrorSpec{0.0001, 1e-3}; + } - Array2D expected({{15.0, -2.0}, {-25.0, 34.0}}); - ComputeAndCompareR2( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); + this->template ComputeAndCompareR2( + &builder, Array2D({{296.14560492846033f}, {0.8611737683031964f}}), + {lhs_handle.get(), rhs_handle.get()}, this->error_spec_); } +template +class SquareMatrixDot : public DotOperationTest { + public: + void TestImpl(bool lhs_row_major, bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 2.0f}, {3.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(lhs_row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 6.0f}, {7.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(rhs_row_major)))) + .ConsumeValueOrDie(); + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs")); + + Array2D expected({{15.0f, -2.0f}, {-25.0f, 34.0f}}); + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, error_spec_); + } +}; + +TYPED_TEST_CASE(SquareMatrixDot, TypesF16F32F64CF64); +XLA_TYPED_TEST(SquareMatrixDot, TypesFF) { this->TestImpl(false, false); } +XLA_TYPED_TEST(SquareMatrixDot, TypesFT) { this->TestImpl(false, true); } +XLA_TYPED_TEST(SquareMatrixDot, TypesTF) { this->TestImpl(true, false); } +XLA_TYPED_TEST(SquareMatrixDot, TypesTT) { this->TestImpl(true, true); } + struct DotTestParam { int m; int k; @@ -225,33 +276,39 @@ string PrintDotTestParam( } class ParametricDotTest : public DotOperationTest, - public ::testing::WithParamInterface {}; + public ::testing::WithParamInterface { + protected: + template + void TestImpl(); +}; -XLA_TEST_P(ParametricDotTest, TestF32) { +template +void ParametricDotTest::TestImpl() { DotTestParam param = GetParam(); - std::unique_ptr> dot_lhs_data = - MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( *dot_lhs_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); - std::unique_ptr> dot_rhs_data = - MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); - std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_rhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); + Layout rhs_layout = LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); + std::unique_ptr dot_rhs_lit = + Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); - std::unique_ptr> addend_data; + std::unique_ptr> addend_data; std::unique_ptr addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { - addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); + addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); addend_lit = Literal::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); @@ -259,24 +316,33 @@ XLA_TEST_P(ParametricDotTest, TestF32) { } ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); + auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), + builder.Parameter(0, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.k}, + MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), "dot_lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), + builder.Parameter(1, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.k, param.n}, + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), "dot_rhs")); if (param.has_addend) { result = builder.Add( - result, - builder.Parameter( - 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); + result, builder.Parameter( + 2, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.n}, + MinorToMajorForIsRowMajor(param.addend_row_major)), + "addend")); } - std::unique_ptr> expected; + std::unique_ptr> expected; if (param.has_addend) { expected = ReferenceUtil::ApplyElementwise2D( - std::plus(), + std::plus(), *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), *addend_data); } else { @@ -287,8 +353,11 @@ XLA_TEST_P(ParametricDotTest, TestF32) { if (param.has_addend) { args.push_back(addend_handle.get()); } - - ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); + ErrorSpec error_spec(0.3, 3e-3); + if (std::is_same::value) { + error_spec = ErrorSpec(0.3, 5e-3); + } + ComputeAndCompareR2(&builder, *expected, args, error_spec); } std::vector CreateDotTestParameters() { @@ -305,30 +374,77 @@ std::vector CreateDotTestParameters() { } }; + add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); + add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); + add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); + + return params; +} + +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(ParametricDotTest, TestF16) { TestImpl(); } +#endif +XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl(); } + +INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, + ::testing::ValuesIn(CreateDotTestParameters()), + PrintDotTestParam); + +class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { + public: + ParametricDotTestWithoutLayoutAssignment() { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "layout-assignment"); + } +}; + +std::vector CreateNoLayoutAssignmentDotTestParameters() { + std::vector params; + auto add_matrix_vector_dot_test = [&](int k, int n) { - for (bool has_addend : {false, true}) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, - /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (n != 1) { - params.push_back( - {/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, - /*has_addend=*/has_addend, /*addend_row_major=*/true}); + for (bool lhs_row_major : {true, false}) { + for (bool rhs_row_major : {true, false}) { + for (bool has_addend : {true, false}) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/true}); + if (has_addend) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/false}); + } + if (n != 1) { + params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/true}); + if (has_addend) { + params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/false}); + } + } + } } } }; - add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); - add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); - add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); - add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/4); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/3); add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); @@ -339,109 +455,60 @@ std::vector CreateDotTestParameters() { return params; } -INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, - ::testing::ValuesIn(CreateDotTestParameters()), - PrintDotTestParam); - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - TestSquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { - TestSquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { - TestSquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - TestSquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { - TestSquareMatrixDot(false, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { - TestSquareMatrixDot(false, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { - TestSquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { - TestSquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { - TestSquareMatrixDot(); -} - -template -void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, - bool rhs_row_major) { - auto lhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) - .ConsumeValueOrDie(); - auto rhs_handle = - client_ - ->TransferToServer(*Literal::CreateR2WithLayout( - {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) - .ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); - auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); - - Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); - - ComputeAndCompareR2( - &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - TestNonsquareMatrixDot(false, false); +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF16) { + TestImpl(); } - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - TestNonsquareMatrixDot(false, true); +#endif +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) { + TestImpl(); } - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - TestNonsquareMatrixDot(true, false); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - TestNonsquareMatrixDot(true, true); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { - TestNonsquareMatrixDot(); -} - -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { - TestNonsquareMatrixDot(false, false); +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF64) { + TestImpl(); } -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { - TestNonsquareMatrixDot(false, true); -} +INSTANTIATE_TEST_CASE_P( + DotTests, ParametricDotTestWithoutLayoutAssignment, + ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()), + PrintDotTestParam); -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { - TestNonsquareMatrixDot(true, false); -} +template +class NonsquareMatrixDot : public DotOperationTest { + public: + void TestImpl(bool lhs_row_major, bool rhs_row_major) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(lhs_row_major)))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateFromArrayWithLayout( + {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, + LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(rhs_row_major)))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs")); + + Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, error_spec_); + } +}; -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { - TestNonsquareMatrixDot(true, true); -} +TYPED_TEST_CASE(NonsquareMatrixDot, TypesF16F32F64CF64); +XLA_TYPED_TEST(NonsquareMatrixDot, TestFF) { this->TestImpl(false, false); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestFT) { this->TestImpl(false, true); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestTF) { this->TestImpl(true, false); } +XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = @@ -468,25 +535,35 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { - ComputationBuilder builder(client_, TestName()); - auto matrix1 = builder.ConstantR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix2 = builder.ConstantR2({{5.0, 6.0}, {7.0, 8.0}}); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { + using T = TypeParam; + + ComputationBuilder builder(this->client_, this->TestName()); + auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); + auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); auto matrix12 = builder.Dot(matrix1, matrix2); auto matrix21 = builder.Dot(matrix2, matrix1); builder.Add(matrix12, matrix21); - Array2D expected({{42.0, 56.0}, {74.0, 96.0}}); - ComputeAndCompareR2(&builder, expected, {}, error_spec_); + Array2D expected({{42.0f, 56.0f}, {74.0f, 96.0f}}); + this->template ComputeAndCompareR2(&builder, expected, {}, + this->error_spec_); } +template +class DotOperationTestForBatchMatMul : public DotOperationTest {}; +TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); + // Regression test for b/32055648. The root of the graph is a kFusion of 4 // bitcasts. Although bitcasts don't map to thunks, the root should still be // sync-dependent on bitcasts' operands. -XLA_TEST_F(DotOperationTest, BatchMatMul) { - ComputationBuilder builder(client_, TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y"); +XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { + using T = TypeParam; + ComputationBuilder builder(this->client_, this->TestName()); + auto x = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); + auto y = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "y"); auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2}); auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); @@ -507,29 +584,42 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { auto out_flat = builder.ConcatInDim(out_slices, 0); builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); - auto x_data = client_ - ->TransferToServer(*Literal::CreateR4( - {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, - {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) - .ConsumeValueOrDie(); - auto y_data = client_ - ->TransferToServer(*Literal::CreateR4( - {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, - {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) + auto x_data = this->client_ + ->TransferToServer(*Literal::CreateR4FromArray4D( + {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, + {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, + {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, + {{4000.0f, 400.0f}, {40.0f, 4.0f}}}})) .ConsumeValueOrDie(); + auto y_data = + this->client_ + ->TransferToServer(*Literal::CreateR4FromArray4D( + {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {{{11.0f, 22.0f}, {33.0f, 44.0f}}, + {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) + .ConsumeValueOrDie(); - ComputeAndCompareR4( + if (std::is_same::value) { + this->error_spec_ = ErrorSpec{0.0001, 1e-3}; + } + this->template ComputeAndCompareR4( &builder, /*expected=*/ - {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, - {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, - {x_data.get(), y_data.get()}, error_spec_); + {{{{1300.0f, 2400.0f}, {13.0f, 24.0f}}, + {{11400.0f, 13600.0f}, {114.0f, 136.0f}}}, + {{{42900.0f, 79200.0f}, {429.0f, 792.0f}}, + {{250800.0f, 299200.0f}, {2508.0f, 2992.0f}}}}, + {x_data.get(), y_data.get()}, this->error_spec_); } -XLA_TEST_F(DotOperationTest, GeneralMatMul) { - ComputationBuilder builder(client_, TestName()); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { + using T = TypeParam; + + ComputationBuilder builder(this->client_, this->TestName()); + auto x = + builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); + auto y = + builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 2, 2}), "y"); DotDimensionNumbers dnums; dnums.add_lhs_contracting_dimensions(2); @@ -539,31 +629,34 @@ XLA_TEST_F(DotOperationTest, GeneralMatMul) { auto out = builder.DotGeneral(x, y, dnums); - auto x_data = client_ - ->TransferToServer(*Literal::CreateR3( - {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) - .ConsumeValueOrDie(); + auto x_data = + this->client_ + ->TransferToServer(*Literal::CreateR3FromArray3D( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) + .ConsumeValueOrDie(); - auto y_data = client_ - ->TransferToServer(*Literal::CreateR3( - {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) - .ConsumeValueOrDie(); + auto y_data = + this->client_ + ->TransferToServer(*Literal::CreateR3FromArray3D( + {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}})) + .ConsumeValueOrDie(); - ComputeAndCompareR3( + this->template ComputeAndCompareR3( &builder, /*expected=*/ - {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, - {x_data.get(), y_data.get()}, error_spec_); + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + {x_data.get(), y_data.get()}, this->error_spec_); } -TEST_F(DotOperationTest, TransposeFolding) { +XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { + using T = TypeParam; for (bool transpose_lhs : {false, true}) { for (bool transpose_rhs : {false, true}) { for (bool row_major : {false, true}) { - std::unique_ptr> lhs( - new Array2D({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}})); - std::unique_ptr> rhs( - new Array2D({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}})); + std::unique_ptr> lhs( + new Array2D({{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}})); + std::unique_ptr> rhs( + new Array2D({{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}})); if (transpose_lhs) { lhs = ReferenceUtil::TransposeArray2D(*lhs); @@ -572,22 +665,20 @@ TEST_F(DotOperationTest, TransposeFolding) { rhs = ReferenceUtil::TransposeArray2D(*rhs); } auto lhs_handle = - client_ - ->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - *lhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + *lhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = - client_ - ->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - *rhs, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(row_major)))) + this->client_ + ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + *rhs, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); + ComputationBuilder builder(this->client_, this->TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); auto lhs_arg = builder.Parameter( 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), "lhs"); @@ -602,24 +693,27 @@ TEST_F(DotOperationTest, TransposeFolding) { } auto result = builder.Dot(lhs_arg, rhs_arg); - Array2D expected({{26.0, 0.0}, {-12.0, 10.0}}); + Array2D expected({{26.0f, 0.0f}, {-12.0f, 10.0f}}); VLOG(1) << "TestTransposeFolding " << transpose_lhs << " " << transpose_rhs << " " << row_major; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - error_spec_); + this->template ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, + this->error_spec_); } } } } -TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { - auto prim_type = primitive_util::NativeToPrimitiveType(); +XLA_TYPED_TEST(DotOperationTest_F16F32F64, + DotOfConcatOptimizationWithConstLHS) { + using T = TypeParam; + auto prim_type = primitive_util::NativeToPrimitiveType(); - std::unique_ptr> constant_lhs_array(new Array2D( - {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, + {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); - ComputationBuilder builder(client_, TestName()); + ComputationBuilder builder(this->client_, this->TestName()); auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); @@ -630,78 +724,80 @@ TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { auto result = builder.Dot( lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); - std::unique_ptr> arg_0_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}})); - std::unique_ptr> arg_1_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); - std::unique_ptr> arg_2_value_array( - new Array2D({{1.0, 2.0}})); + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}})); + std::unique_ptr> arg_2_value_array(new Array2D({{1.0f, 2.0f}})); TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); - Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); - ComputeAndCompareR2( + Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); + this->template ComputeAndCompareR2( &builder, expected, - {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); -} - -TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { - auto prim_type = primitive_util::NativeToPrimitiveType(); - - std::unique_ptr> constant_rhs_array( - new Array2D({{1.0, 2.0}, - {3.0, 4.0}, - {5.0, 6.0}, - {6.0, 5.0}, - {4.0, 3.0}, - {2.0, 1.0}})); - - ComputationBuilder builder(client_, TestName()); + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, + this->error_spec_); +} + +XLA_TYPED_TEST(DotOperationTest_F16F32F64, + DotOfConcatOptimizationWithConstRHS) { + using T = TypeParam; + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}, + {6.0f, 5.0f}, + {4.0f, 3.0f}, + {2.0f, 1.0f}})); + + ComputationBuilder builder(this->client_, this->TestName()); auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); - auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); - auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType({2, 3}), "lhs_arg_1"); - auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType({2, 1}), "lhs_arg_2"); auto result = builder.Dot( builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); - std::unique_ptr> arg_0_value_array( - new Array2D({{1.0, 2.0}, {3.0, 4.0}})); - std::unique_ptr> arg_1_value_array( - new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); - std::unique_ptr> arg_2_value_array( - new Array2D({{1.0}, {2.0}})); + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0f, 2.0f}, {3.0f, 4.0f}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0f}, {2.0f}})); TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_0_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_1_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, - client_->TransferToServer( - *Literal::CreateR2FromArray2D(*arg_2_value_array))); + this->client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); - Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); - ComputeAndCompareR2( + Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); + this->template ComputeAndCompareR2( &builder, expected, - {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, + this->error_spec_); } + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 877dc7db0eec229a7119b3627f177a33ed0d971b..5f00c34002803553b9c17b4fce0abafda7369796 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" @@ -112,10 +111,8 @@ class DynamicSliceTest : public ClientLibraryTestBase { void TestR3Wrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR3( - {{{1, 2}, {3, 4}, {5, 6}}, - {{7, 8}, {9, 10}, {11, 12}}}, - {0, 2, 1}, {2, 1, 2}, - {{{6, 5}}, {{12, 11}}}); + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, + {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); } template @@ -137,9 +134,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -163,9 +160,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -189,9 +186,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -206,19 +203,19 @@ XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } -XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -281,6 +278,15 @@ XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { class DynamicUpdateSliceTest : public ClientLibraryTestBase { protected: + template + void TestR0() { + // Disable algebraic simplifier, otherwise the op will be replaced by a + // constant. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); + RunR0(0, 123, {}, 123); + } + template void TestR1() { // Slice at dimension start. @@ -341,6 +347,35 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); } + template + void RunR0(int input_value_int, int update_value_int, + const std::vector slice_starts, int expected_value_int) { + Literal input_value = + std::move(*Literal::CreateR0(input_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_value = + std::move(*Literal::CreateR0(update_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_value = + std::move(*Literal::CreateR0(expected_value_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + + ComputationBuilder builder(client_, TestName()); + // Initialize and transfer dynamic slice start indices parameter. + ComputationDataHandle starts; + std::unique_ptr start_data = CreateR1Parameter( + slice_starts, 0, "slice_starts", &builder, &starts); + // Build dynamic slice computation. + auto input = builder.ConstantLiteral(input_value); + auto update = builder.ConstantLiteral(update_value); + builder.DynamicUpdateSlice(input, update, starts); + // Run computation and compare against expected values. + ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()}); + } + template void RunR1(tensorflow::gtl::ArraySlice input_values_int, tensorflow::gtl::ArraySlice update_values_int, @@ -359,9 +394,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -390,9 +425,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -421,9 +456,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -474,13 +509,13 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } // Build dynamic slice computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer input parameter. - ComputationDataHandle input; + XlaOp input; std::unique_ptr input_data = CreateR3Parameter(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. - ComputationDataHandle update; + XlaOp update; std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1({index, 0, 0}); @@ -500,13 +535,18 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0BF16) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32R0) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } + // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) { @@ -672,7 +712,7 @@ void BM_DynamicSlice(int num_iters) { TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); - ComputationBuilder builder(client, "DynamicSlice"); + XlaBuilder builder("DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] auto input_literal = Literal::CreateR4( diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index 6fe7737de7af349dca2931b52d62dbc03b14e0b3..b28fe0c15a89a1331698a29f70b966380bd3fcb9 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -71,8 +71,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #ifdef XLA_TEST_BACKEND_CPU // TODO(b/73141998): The vectorized Log implementation gives results outside // our error spec in this range (these numbers are bitwise representations of - // floats expressed as a zero extended int64): - std::pair known_incorrect_range = {1, 8315654}; + // floats expressed as a zero extended int64). + std::pair known_incorrect_range = {1, 8388608}; #else std::pair known_incorrect_range = {0, 0}; #endif diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..90496d55e60b4f45fc2d46b2746f94d775cf9f94 --- /dev/null +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -0,0 +1,461 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +// NB! TODO(b/74360564): These tests do not test out of bounds behavior since +// that hasn't been specced yet. + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class GatherOperationTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* gather_indices) { + RunTest(hlo_text, {operand, gather_indices}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherV1) { + const string hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule TensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { + const string hlo_text = R"( +HloModule TensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { + const string hlo_text = R"( +HloModule TensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + ROOT gather = s32[2,1,1,2] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { + const string hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule TensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, DynamicSlice) { + const char* hlo_text = R"( +HloModule DynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + ROOT gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,0] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 0} +} +)"; + std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { + // Out of bounds indices must not crash, and the indices in range should + // produce the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for OOB accesses, + // we should get rid of the mask and check that backends produce the same + // value for OOB indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, NegativeIndex) { + // Negative indices must not crash, and the indices in range should produce + // the same values across all backends. + // + // TODO(b/74360564): Once we have a well defined semantics for negative + // accesses, we should get rid of the mask and check that backends produce the + // same value for negative indices too. + + const string hlo_text = R"( +HloModule BatchDynamicSlice + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + gather = s32[6,1,1]{2,1,0} gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} + gather_reshaped = s32[6]{0} reshape(gather) + in_bounds_mask = s32[6]{0} parameter(2) + ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR2( + {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr in_bounds_mask = + Literal::CreateR1({0, 1, 1, 0, 0, 1}); + + RunTest(hlo_text, + {operand.get(), gather_indices.get(), in_bounds_mask.get()}); +} + +XLA_TEST_F(GatherOperationTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), + output_window_dims={0,1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1,3,2} +} +)"; + std::unique_ptr operand = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ScalarResult) { + const char* hlo_text = R"( +HloModule ScalarResult + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + ROOT gather = s32[] gather(operand, index), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=0, + window_bounds={1} +} +)"; + std::unique_ptr operand = Literal::CreateR1({1, 2, 3, 4}); + std::unique_ptr gather_indices = Literal::CreateR0(1); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { + const string hlo_text = R"( +HloModule ZeroSizedResult + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[0] parameter(1) + ROOT gather = s32[0,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1, 3} +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +class GatherClientLibraryTest : public ClientLibraryTestBase {}; + +// TODO(b/30671675): Asynchronous execution on stream is not yet supported on +// GPU and CPU_PARALLEL. +XLA_TEST_F(GatherClientLibraryTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) { + // We create this HLO, but using the XlaBuilder API. + // + // ENTRY main { + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // ROOT gather = s32[2,3] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0}, + // gather_dims_to_operand_dims={0}, + // index_vector_dim=1, + // window_bounds={1, 3} + // } + + XlaBuilder builder("gather_basic"); + + Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {2}); + + auto operand = builder.Parameter(0, operand_shape, "operand"); + auto indices = builder.Parameter(1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_output_window_dims(1); + dim_numbers.add_elided_window_dims(0); + dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.set_index_vector_dim(1); + builder.Gather(operand, indices, dim_numbers, {1, 3}); + + std::vector expected = {}; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr operand_arg, + client_->TransferToServer(*Literal::CreateR2( + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr indices_arg, + client_->TransferToServer(*Literal::CreateR1({0, 2}))); + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + *execution_options.add_device_handles() = devices[0]; + TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, builder.Build()); + std::vector computation_instances = { + {computation, + {operand_arg.get(), indices_arg.get()}, + execution_options, + /*execution_profile=*/nullptr}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector> result_data, + client_->ExecuteParallel(computation_instances)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + client_->Transfer(*(result_data[0]))); + LiteralTestUtil::ExpectEqual( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc index eded2077fce965ab1c729c610764afa2228ca128..cf971dd61b71ad329b20b0bb7c16166126562681 100644 --- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc +++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" @@ -30,7 +29,7 @@ class HloMetadataTest : public LocalClientTestBase { metadata_.set_op_name("my_sum_op"); } - void BuildAddComputation(ComputationBuilder* builder) { + void BuildAddComputation(XlaBuilder* builder) { auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder->Add(x, y); @@ -40,7 +39,7 @@ class HloMetadataTest : public LocalClientTestBase { }; TEST_F(HloMetadataTest, MetadataPropagation) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); BuildAddComputation(&builder); builder.ClearOpMetadata(); @@ -61,7 +60,7 @@ TEST_F(HloMetadataTest, MetadataPropagation) { } TEST_F(HloMetadataTest, MetadataClearing) { - ComputationBuilder builder(local_client_, "add"); + XlaBuilder builder("add"); builder.SetOpMetadata(metadata_); // Some other pretend computation here. builder.ClearOpMetadata(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 9f5806c5e16c30cf198027cffab5f78c315cb957..21f71fc91bb84540e5347811cb4643a8aeda445c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -91,7 +91,7 @@ HloTestBase::HloTestBase() HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = MakeUnique(); + hlo_verifier_ = MakeUnique(/*allow_mixed_precision=*/true); } /* static */ @@ -115,6 +115,13 @@ StatusOr> HloTestBase::Execute( return test_runner_.Execute(std::move(module), arguments); } +StatusOr> HloTestBase::ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments, + /*run_hlo_passes=*/false); +} + std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { @@ -135,22 +142,15 @@ StatusOr> HloTestBase::MakeReferenceModule( "reference preprocessor must not modify the program shape"); } } - TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), - reference_module.get())); + TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status()); return std::move(reference_module); } -template StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { - static_assert( - std::is_same::value || - std::is_same, LiteralPtr>::value, - "The LiteralPtr type only accepts Literal* or std::unique_ptr."); - TF_RETURN_IF_ERROR( - VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); TF_ASSIGN_OR_RETURN(auto reference_module, MakeReferenceModule(*module, reference_preprocessor)); @@ -165,9 +165,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( error); } -template ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -179,9 +178,8 @@ template return result.ValueOrDie(); } -template ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, const ArraySlice arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -198,8 +196,14 @@ template const std::function& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompare>( - std::move(module), fake_arguments, error, reference_preprocessor); + + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + + return RunAndCompare(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( @@ -207,8 +211,13 @@ template const std::function& reference_preprocessor) { const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); - return RunAndCompareNoHloPasses>( - std::move(module), fake_arguments, error, reference_preprocessor); + std::vector fake_argument_ptrs; + c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const std::unique_ptr& literal) { return literal.get(); }); + + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, + reference_preprocessor); } ::testing::AssertionResult HloTestBase::RunAndCompare( @@ -267,6 +276,28 @@ template reference_preprocessor); } +HloComputation* HloTestBase::FindComputation(HloModule* module, + tensorflow::StringPiece name) { + auto it = c_find_if(module->computations(), + [&](HloComputation* c) { return c->name() == name; }); + if (it == module->computations().end()) { + return nullptr; + } + return *it; +} + +HloInstruction* HloTestBase::FindInstruction(HloModule* module, + tensorflow::StringPiece name) { + for (const HloComputation* c : module->computations()) { + auto it = c_find_if(c->instructions(), + [&](HloInstruction* i) { return i->name() == name; }); + if (it != c->instructions().end()) { + return *it; + } + } + return nullptr; +} + Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4aea9fc9fd027231106e529eb16bcd43f23fbe1c..3e8e2360bb3a87e127920cd222803c0f7b9161f4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -44,7 +44,7 @@ namespace xla { // enables, for one, explicitly building a graph of HLO instructions to run. // // This can also be used to write text/file-based test cases. Note that the test -// target is responsible for linking the needed backends. A covenient way to do +// target is responsible for linking the needed backends. A convenient way to do // this is to make it an xla_test: it will generate test targets linking with // the respective backends, which will be used as the test backend; the // interpreter backend is already linked with hlo_test_base so it will be the @@ -98,14 +98,19 @@ class HloTestBase : public ::testing::Test { std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); + // Same as above, except the module will be executed without running any HLO + // passes on it. + StatusOr> ExecuteNoHloPasses( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + std::unique_ptr ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); // Executes the given hlo module on two backends and compares results. // - // 'arguments': the input of the hlo module. The LiteralPtr type accepts - // Literal* or std::unique_ptr. + // 'arguments': the input of the hlo module. // // 'error': if has value, expects the results to be near (within the error // bound). Otherwise, expects the results to be equal. @@ -114,20 +119,18 @@ class HloTestBase : public ::testing::Test { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - template ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Same as above, except that the module will be executed without Hlo // optimization. - template ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -197,6 +200,15 @@ class HloTestBase : public ::testing::Test { ->Clear(); } + // Gets the computation/instruction from the given module with the given name. + // + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + HloComputation* FindComputation(HloModule* module, + tensorflow::StringPiece name); + HloInstruction* FindInstruction(HloModule* module, + tensorflow::StringPiece name); + // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } @@ -223,10 +235,9 @@ class HloTestBase : public ::testing::Test { // Runs the module on two platforms with or without running hlo passes and // compares the results. Returns whether the results are near or equal. If any // error happens before the results are computed, returns the error status. - template StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::ArraySlice arguments, const tensorflow::gtl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 506091ddd8d1d8e6519525bb7031f4e8b296b5fb..da4cf4ae0c31bc194cd2ec9b845df36afbde69b0 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -40,18 +41,22 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; - } + VerifyModule(); } HloTestBase::TearDown(); } +void HloVerifiedTestBase::VerifyModule() { + HloVerifier verifier; + xla::StatusOr mutated = verifier.Run(module_.get()); + if (!mutated.ok()) { + ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); + } else { + EXPECT_FALSE(mutated.ValueOrDie()) + << "HloVerifier should never mutate the HloModule"; + } +} + HloModule& HloVerifiedTestBase::module() { if (!module_) { module_ = CreateNewModule(); @@ -59,4 +64,10 @@ HloModule& HloVerifiedTestBase::module() { return *module_; } +void HloVerifiedTestBase::ParseAndVerifyModule( + tensorflow::StringPiece hlo_text) { + CHECK(!module_) << "Called ParseModule when test already has a module."; + TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); + VerifyModule(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 492688bf7d682cf991cb8c09399492a0437f651b..e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,6 +44,7 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. @@ -55,6 +56,7 @@ class HloVerifiedTestBase : public HloTestBase { std::unique_ptr module_; // Lazily populated. Access via module(). std::unique_ptr shape_verifier_; bool tear_down_called_ = false; + void VerifyModule(); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 5aa71a9261dbd414d1499f15c9b83cd63b634b49..81630df34c58526b6d41492b2b4b3892a02a21c2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -209,6 +209,11 @@ template <> return CompareFloatsBitwiseEqual(lhs, rhs); } template <> +::testing::AssertionResult CompareEqual(Eigen::half lhs, + Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> ::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 99514baf23cafe61adc28a30dfdfe2691ab82d32..3023df47cda33f5d11abc921fd0355d48f761107 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -49,11 +50,11 @@ void LLVMIRGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - ASSERT_TRUE(CompileToExecutable(std::move(hlo_module)).ok()); + TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status()); ResetIrHook(); StatusOr filecheck_result = RunFileCheck(ir_, pattern); - ASSERT_TRUE(filecheck_result.ok()); + TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(filecheck_result.ValueOrDie()); } diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2b0f7e6e80c48435ca55432a2afa3b6d69162625..efe6cc67872713a8aeecc11aeafe4902676817a6 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,18 +52,18 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} ----> (add) // / // 1.0f ---------/ - Computation CreateAdderToOne() { - ComputationBuilder mapped_builder(client_, TestName()); + XlaComputation CreateAdderToOne() { + XlaBuilder mapped_builder(TestName()); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = mapped_builder.ConstantR0(1.0); - auto adder_to_one = mapped_builder.Add(x, one); + mapped_builder.Add(x, one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); } - Computation CreateMax() { - ComputationBuilder b(client_, TestName()); + XlaComputation CreateMax() { + XlaBuilder b(TestName()); auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); b.Max(lhs, rhs); @@ -73,8 +75,8 @@ class MapTest : public ClientLibraryTestBase { // Creates a computation that accepts an F32 and returns T(1) (ignoring the // argument). template - Computation CreateScalarOne() { - ComputationBuilder mapped_builder(client_, "scalar_one"); + XlaComputation CreateScalarOne() { + XlaBuilder mapped_builder("scalar_one"); (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); mapped_builder.ConstantR0(1); auto computation_status = mapped_builder.Build(); @@ -87,11 +89,11 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} ----> (mul) // / // 2.0f ---------/ - Computation CreateMulByTwo() { - ComputationBuilder mapped_builder(client_, TestName()); + XlaComputation CreateMulByTwo() { + XlaBuilder mapped_builder(TestName()); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto two = mapped_builder.ConstantR0(2.0); - auto mul_by_two = mapped_builder.Mul(x, two); + mapped_builder.Mul(x, two); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -105,12 +107,12 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} ----> (add) ----> (mul) // / // 1.0f ---------/ - Computation CreateAdderToOneTimesItself() { - ComputationBuilder mapped_builder(client_, TestName()); + XlaComputation CreateAdderToOneTimesItself() { + XlaBuilder mapped_builder(TestName()); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = mapped_builder.ConstantR0(1.0); auto adder_to_one = mapped_builder.Add(x, one); - auto result = mapped_builder.Mul(x, adder_to_one); + mapped_builder.Mul(x, adder_to_one); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -122,12 +124,13 @@ class MapTest : public ClientLibraryTestBase { // x {R0F32} -----------> (map) ----> (add) // / / // embedded_computation --/ n --/ - Computation CreateMapPlusN(const Computation& embedded_computation, float n) { - ComputationBuilder builder(client_, TestName()); + XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation, + float n) { + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto map = builder.Map({x}, embedded_computation, {}); auto constant_n = builder.ConstantR0(n); - auto add = builder.Add(map, constant_n); + builder.Add(map, constant_n); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -135,11 +138,11 @@ class MapTest : public ClientLibraryTestBase { // Creates a binary function with signature (F32, F32) -> Pred // defined by (x, y) -> x > y. - Computation CreateGt() { - ComputationBuilder b(client_, "Gt"); + XlaComputation CreateGt() { + XlaBuilder b("Gt"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); - auto gt = b.Gt(x, y); + b.Gt(x, y); auto computation_status = b.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -152,13 +155,13 @@ class MapTest : public ClientLibraryTestBase { // y {R0F32} ----> (add) ---> (add) // / // z {R0F32} ---------------/ - Computation CreateTernaryAdder() { - ComputationBuilder mapped_builder(client_, "TernaryAdder"); + XlaComputation CreateTernaryAdder() { + XlaBuilder mapped_builder("TernaryAdder"); auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z"); auto xy = mapped_builder.Add(x, y); - auto xyz = mapped_builder.Add(xy, z); + mapped_builder.Add(xy, z); auto computation_status = mapped_builder.Build(); TF_CHECK_OK(computation_status.status()); return computation_status.ConsumeValueOrDie(); @@ -167,13 +170,13 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {}); + builder.Map({param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -181,13 +184,13 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {0}); + builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -195,55 +198,55 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {0}); + builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); } TEST_F(MapTest, MapEachF32ElementToS32Constant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne(), {0}); + builder.Map({param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } TEST_F(MapTest, MapEachF32ElementToU32Constant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne(), {0}); + builder.Map({param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0}); + builder.Map({param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -253,14 +256,14 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); + builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -269,7 +272,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = @@ -277,7 +280,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { auto param = builder.Parameter(0, param0_literal->shape(), "param0"); auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); - auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); + builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -285,14 +288,14 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne(), {0, 1}); + builder.Map({param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -317,18 +320,18 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { auto embed2 = CreateMapPlusN(embed1, 2.0); auto embed3 = CreateMapPlusN(embed1, 4.0); - ComputationBuilder embed4_builder(client_, "embed4"); + XlaBuilder embed4_builder("embed4"); auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); - auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); + embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); auto embed4 = embed4_status.ConsumeValueOrDie(); auto embed5 = CreateMapPlusN(embed2, 6.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto constant_42 = builder.ConstantR0(42.0); auto constant_7 = builder.ConstantR0(7.0); auto map_42 = builder.Map({constant_42}, embed5, {}); @@ -359,7 +362,8 @@ TEST_F(MapTest, VersionedEmbeddedComputation) { // Add another Add(1) operation to the existing embedded computation. This // requires using the stub interface because the ComputationBuilder does not - // allow modification to the Computation objects after they have been built. + // allow modification to the XlaComputation objects after they have been + // built. BinaryOpRequest request; request.set_binop(BINOP_ADD); *request.mutable_lhs() = adder_to_one; @@ -381,7 +385,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) { TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = @@ -393,8 +397,7 @@ TEST_F(MapTest, MapBinaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, - CreateScalarAddComputation(F32, &builder), {0}); + builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -404,7 +407,7 @@ TEST_F(MapTest, MapBinaryAdder) { // Adds two rank-2 arrays with different layouts. This test exercises a path // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = @@ -417,8 +420,8 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, - CreateScalarAddComputation(S32, &builder), {0, 1}); + builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1}); Array2D expected(2, 2); expected(0, 0) = 11; @@ -430,7 +433,7 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { } XLA_TEST_F(MapTest, AddR3_3x0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = @@ -443,8 +446,8 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, - CreateScalarAddComputation(S32, &builder), {0, 1, 2}); + builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder), + {0, 1, 2}); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -452,7 +455,7 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = @@ -469,7 +472,7 @@ TEST_F(MapTest, MapTernaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); + builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -479,24 +482,24 @@ TEST_F(MapTest, MapTernaryAdder) { TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto gt = CreateGt(); b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt, {0}); ComputeAndCompareR1(&b, {false, true}, {}); } TEST_F(MapTest, NestedBinaryMap) { - Computation max_with_square; + XlaComputation max_with_square; { // max_with_square(x) = do max(x, x^2) via a map. - ComputationBuilder b(client_, "max_with_square"); + XlaBuilder b("max_with_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); b.Map({x, b.Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); b.Map({input}, max_with_square, {0}); ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); @@ -505,13 +508,13 @@ TEST_F(MapTest, NestedBinaryMap) { TEST_F(MapTest, MapOperantionWithBuildError) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported // type combination (F32 + U16) to test that the error is reported to the - // outermost ComputationBuilder. - ComputationBuilder builder(client_, TestName()); + // outermost XlaBuilder. + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("ErrorAdd"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y"); - auto adder = sub_builder->Add(x, y); + sub_builder->Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = @@ -525,13 +528,13 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, error_add, {0}); + builder.Map({param0, param1}, error_add, {0}); - StatusOr computation_status = builder.Build(); + StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: binary op BINOP_ADD with " + ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " "different element types: f32[] and u16[]")); } @@ -545,7 +548,7 @@ using MapTestWithFullOpt = ClientLibraryTestBase; // to have issues with such patterns and maybe invalidate the pointer to entry // computation. TEST_F(MapTestWithFullOpt, MapScalarPower) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); @@ -572,7 +575,7 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { // Regression test for b/35786417, where the inliner would not notice the change // of parameter order inside the map. TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); @@ -598,7 +601,7 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { // Regression test for b/35786417, where the inliner would CHECK-fail due to the // mul inside the map having more parameters than the map does. TEST_F(MapTestWithFullOpt, MapSquare) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto sub_builder = builder.CreateSubBuilder("power"); auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 6c86dd5b9ef673c9facffafa37e00a859ce82010..c42f71388baba73e08a361d817e41b03e03bf133 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -29,6 +29,8 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" @@ -38,258 +40,223 @@ limitations under the License. namespace xla { namespace { -class MatOpsSimpleTest : public ClientLibraryTestBase { - protected: - Computation BuildSum() { - // sum(x, y) = x + y - ComputationBuilder builder(client_, "sum"); - auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto y_value = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y_value"); - builder.Add(x_value, y_value); - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return computation_status.ConsumeValueOrDie(); - } - - void TestLinspaceMax(int64 rows, int64 cols) { - float from = -128.0, to = 256.0; - std::unique_ptr> alhs = - MakeLinspaceArray2D(from, to, rows, cols); - auto arhs = MakeUnique>(rows, cols, 1.0); - - ComputationBuilder builder( - client_, - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); - auto lhs = builder.ConstantR2FromArray2D(*alhs); - auto rhs = builder.ConstantR2FromArray2D(*arhs); - auto max = builder.Max(lhs, rhs); - - Array2D aexpected(rows, cols); - for (int row = 0; row < rows; ++row) { - for (int col = 0; col < cols; ++col) { - aexpected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); - } - } - - ComputeAndCompareR2(&builder, aexpected, {}, ErrorSpec(1e-6)); - } -}; - -TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { - ComputationBuilder builder(client_, "exp_2x2"); - auto data = builder.ConstantR2({ - {1.0, 0.0}, // row 0 - {-1.0, 0.5}, // row 1 +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TypesF16F32 = ::testing::Types; +#else +using TypesF16F32 = ::testing::Types; +#endif + +class MatOpsSimpleTest : public ClientLibraryTestBase {}; + +template +class MatOpsSimpleTest_F16F32 : public MatOpsSimpleTest {}; + +// TODO(bixia): This test for F16 failed on GPU 02-25-2018. +#ifdef XLA_TEST_BACKEND_GPU +TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, ::testing::Types); +#else +TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); +#endif + +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { + using T = TypeParam; + ComputationBuilder builder(this->client_, "exp_2x2"); + auto data = builder.ConstantR2FromArray2D({ + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 }); builder.Exp(data); std::unique_ptr expected = - Literal::CreateR2({{2.71828, 1.00000}, // row 0 - {0.36788, 1.64872}}); // row 1 + Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 + {0.36788f, 1.64872f}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->template ComputeAndCompareLiteral(&builder, *expected, {}, + ErrorSpec(1e-5)); } -TEST_F(MatOpsSimpleTest, MapTwoByTwo) { +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { + using T = TypeParam; Computation add_half; { // add_half(x) = x + 0.5 - ComputationBuilder builder(client_, "add_half"); + ComputationBuilder builder(this->client_, "add_half"); auto x_value = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); - auto half = builder.ConstantR0(0.5); + builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); + auto half = builder.ConstantR0(static_cast(0.5)); builder.Add(x_value, half); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); add_half = computation_status.ConsumeValueOrDie(); } - ComputationBuilder builder(client_, "map_2x2"); - auto data = builder.ConstantR2({ - {1.0, 0.0}, // row 0 - {-1.0, 0.5}, // row 1 + ComputationBuilder builder(this->client_, "map_2x2"); + auto data = builder.ConstantR2FromArray2D({ + {1.0f, 0.0f}, // row 0 + {-1.0f, 0.5f}, // row 1 }); auto map = builder.Map({data}, add_half, {0, 1}); std::unique_ptr expected = - Literal::CreateR2({{1.5, 0.5}, // row 0 - {-0.5, 1.0}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 + {-0.5f, 1.0f}}); // row 1 + this->template ComputeAndCompareLiteral(&builder, *expected, {}, + ErrorSpec(1e-5)); } -TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { - ComputationBuilder builder(client_, "max_2x2"); - auto lhs = builder.ConstantR2({ - {7.0, 2.0}, // row 0 - {3.0, -4.0}, // row 1 +XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { + using T = TypeParam; + ComputationBuilder builder(this->client_, "max_2x2"); + auto lhs = builder.ConstantR2FromArray2D({ + {7.0f, 2.0f}, // row 0 + {3.0f, -4.0f}, // row 1 }); - auto rhs = builder.ConstantR2({ - {5.0, 6.0}, // row 0 - {1.0, -8.0}, // row 1 + auto rhs = builder.ConstantR2FromArray2D({ + {5.0f, 6.0f}, // row 0 + {1.0f, -8.0f}, // row 1 }); auto max = builder.Max(lhs, rhs); std::unique_ptr expected = - Literal::CreateR2({{7.0, 6.0}, // row 0 - {3.0, -4.0}}); // row 1 - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 + {3.0f, -4.0f}}); // row 1 + this->template ComputeAndCompareLiteral(&builder, *expected, {}, + ErrorSpec(1e-6)); } -TEST_F(MatOpsSimpleTest, Max1x1Linspace) { TestLinspaceMax(1, 1); } - -TEST_F(MatOpsSimpleTest, Max2x2Linspace) { TestLinspaceMax(2, 2); } - -TEST_F(MatOpsSimpleTest, Max3x3Linspace) { TestLinspaceMax(3, 3); } - -TEST_F(MatOpsSimpleTest, Max4x4Linspace) { TestLinspaceMax(4, 4); } - -TEST_F(MatOpsSimpleTest, Max6x6Linspace) { TestLinspaceMax(6, 6); } - -TEST_F(MatOpsSimpleTest, Max8x8Linspace) { TestLinspaceMax(8, 8); } - -TEST_F(MatOpsSimpleTest, Max12x12Linspace) { TestLinspaceMax(12, 12); } - -TEST_F(MatOpsSimpleTest, Max16x16Linspace) { TestLinspaceMax(16, 16); } +struct TestLinspaceMaxParam { + int64 rows; + int64 cols; +}; -TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); } +class TestLinspaceMaxParametric + : public MatOpsSimpleTest, + public ::testing::WithParamInterface { + public: + template + void TestImpl() { + TestLinspaceMaxParam param = GetParam(); + int64 rows = param.rows; + int64 cols = param.cols; + float from = -128.0, to = 256.0; + std::unique_ptr> alhs = + MakeLinspaceArray2D(from, to, rows, cols); + auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); -TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); } + ComputationBuilder builder( + client_, + tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + auto lhs = builder.ConstantR2FromArray2D(*alhs); + auto rhs = builder.ConstantR2FromArray2D(*arhs); + auto max = builder.Max(lhs, rhs); -class MatOpsDotAddTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface> {}; - -TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { - bool row_major = std::get<0>(GetParam()); - bool add_lhs = std::get<1>(GetParam()); - bool transpose = std::get<2>(GetParam()); - Array2D lhs({{1.0, 2.0}, {3.0, 4.0}}); - Array2D rhs({{10.0, 11.0}, {12.0, 13.0}}); - - auto minor_to_major = [](bool row_major) -> std::vector { - return {row_major ? 1 : 0, row_major ? 0 : 1}; - }; - - auto prim_type = primitive_util::NativeToPrimitiveType(); - Shape lhs_shape = - ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); - Shape rhs_shape = - ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - - TF_ASSERT_OK_AND_ASSIGN( - auto lhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSERT_OK_AND_ASSIGN( - auto rhs_handle, - client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - - ComputationBuilder builder(client_, TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); - auto lhs_mat_arg = lhs_arg; - if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); - } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); - Array2D expected; - if (add_lhs) { - result = builder.Add(result, lhs_arg); - if (transpose) { - expected = Array2D({{47, 52}, {71, 78}}); - } else { - expected = Array2D({{35, 39}, {81, 89}}); + Array2D expected(rows, cols); + for (int row = 0; row < rows; ++row) { + for (int col = 0; col < cols; ++col) { + expected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col)); + } } - } else { - result = builder.Add(result, rhs_arg); - if (transpose) { - expected = Array2D({{56, 61}, {80, 87}}); - } else { - expected = Array2D({{44, 48}, {90, 98}}); + ErrorSpec error_spec(1e-6); + if (std::is_same::value) { + error_spec = ErrorSpec(1e-6, 2e-4); } + ComputeAndCompareR2(&builder, expected, {}, error_spec); } +}; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(1e-6)); +string PrintTestLinspaceMaxParam( + const ::testing::TestParamInfo& test_param) { + const TestLinspaceMaxParam& param = test_param.param; + return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); } -INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, - ::testing::Combine(::testing::Bool(), ::testing::Bool(), - ::testing::Bool())); +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +// TODO(bixia): This test failed on GPU 02-25-2018 +#ifdef XLA_TEST_BACKEND_CPU +XLA_TEST_P(TestLinspaceMaxParametric, TestF16) { TestImpl(); } +#endif +#endif +XLA_TEST_P(TestLinspaceMaxParametric, TestF32) { TestImpl(); } + +INSTANTIATE_TEST_CASE_P( + TestLinspaceMax, TestLinspaceMaxParametric, + ::testing::Values(TestLinspaceMaxParam{1, 1}, TestLinspaceMaxParam{2, 2}, + TestLinspaceMaxParam{3, 3}, TestLinspaceMaxParam{4, 4}, + TestLinspaceMaxParam{6, 6}, TestLinspaceMaxParam{8, 8}, + TestLinspaceMaxParam{12, 12}, + TestLinspaceMaxParam{16, 16}, TestLinspaceMaxParam{32, 8}, + TestLinspaceMaxParam{64, 8}), + PrintTestLinspaceMaxParam); -class MatOpsDotAddTest_bf16 +class MatOpsDotAddTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface> {}; - -TEST_P(MatOpsDotAddTest_bf16, Dot_Add_2x2_2x2) { - bool row_major = std::get<0>(GetParam()); - bool add_lhs = std::get<1>(GetParam()); - bool transpose = std::get<2>(GetParam()); - Array2D lhs( - {{bfloat16(1.0f), bfloat16(2.0f)}, {bfloat16(3.0), bfloat16(4.0)}}); - Array2D rhs( - {{bfloat16(10.0f), bfloat16(11.0f)}, {bfloat16(12.0f), bfloat16(13.0f)}}); - - auto minor_to_major = [](bool row_major) -> std::vector { - return {row_major ? 1 : 0, row_major ? 0 : 1}; - }; - - auto prim_type = primitive_util::NativeToPrimitiveType(); - Shape lhs_shape = - ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); - Shape rhs_shape = - ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); - - TF_ASSERT_OK_AND_ASSIGN( - auto lhs_handle, - client_->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - TF_ASSERT_OK_AND_ASSIGN( - auto rhs_handle, - client_->TransferToServer( - *Literal::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - - ComputationBuilder builder(client_, TestName()); - auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); - auto lhs_mat_arg = lhs_arg; - if (transpose) { - lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); - } - auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); - auto result = builder.Dot(lhs_mat_arg, rhs_arg); - Array2D expected; - if (add_lhs) { - result = builder.Add(result, lhs_arg); + public ::testing::WithParamInterface> { + public: + template + void TestImpl() { + bool row_major = std::get<0>(GetParam()); + bool add_lhs = std::get<1>(GetParam()); + bool transpose = std::get<2>(GetParam()); + Array2D lhs({{1.0f, 2.0f}, {3.0f, 4.0f}}); + Array2D rhs({{10.0f, 11.0f}, {12.0f, 13.0f}}); + + auto minor_to_major = [](bool row_major) -> std::vector { + return {row_major ? 1 : 0, row_major ? 0 : 1}; + }; + + auto prim_type = primitive_util::NativeToPrimitiveType(); + Shape lhs_shape = + ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); + Shape rhs_shape = + ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); + + TF_ASSERT_OK_AND_ASSIGN( + auto lhs_handle, + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSERT_OK_AND_ASSIGN( + auto rhs_handle, + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + + ComputationBuilder builder(client_, TestName()); + auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_mat_arg = lhs_arg; if (transpose) { - expected = Array2D( - {{bfloat16(47), bfloat16(52)}, {bfloat16(71), bfloat16(78)}}); - } else { - expected = Array2D( - {{bfloat16(35), bfloat16(39)}, {bfloat16(81), bfloat16(89)}}); + lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); } - } else { - result = builder.Add(result, rhs_arg); - if (transpose) { - expected = Array2D( - {{bfloat16(56), bfloat16(61)}, {bfloat16(80), bfloat16(87)}}); + auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); + auto result = builder.Dot(lhs_mat_arg, rhs_arg); + Array2D expected; + if (add_lhs) { + result = builder.Add(result, lhs_arg); + if (transpose) { + expected = Array2D({{47.0f, 52.0f}, {71.0f, 78.0f}}); + } else { + expected = Array2D({{35.0f, 39.0f}, {81.0f, 89.0f}}); + } } else { - expected = Array2D( - {{bfloat16(44), bfloat16(48)}, {bfloat16(90), bfloat16(98)}}); + result = builder.Add(result, rhs_arg); + if (transpose) { + expected = Array2D({{56.0f, 61.0f}, {80.0f, 87.0f}}); + } else { + expected = Array2D({{44.0f, 48.0f}, {90.0f, 98.0f}}); + } } + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + ErrorSpec(1e-6)); } +}; - ComputeAndCompareR2(&builder, expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(1e-6)); -} +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2BF16) { TestImpl(); } +#ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2F16) { TestImpl(); } +#endif +XLA_TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2F32) { TestImpl(); } -INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest_bf16, +INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool())); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 8cef8dd34dc7b16b1e58ded67d6b6a4ba79f20db..ce295b832d79e4f00656f2893c2ba1162693dd73 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -85,7 +85,7 @@ class PadTestFloat : public PadTest, // Tests a Pad() with a zero-element input and output. XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Set up the padding configuration {low: 0, high: 0, interior: 0}. PaddingConfig padding_config; auto dimension = padding_config.add_dimensions(); @@ -100,7 +100,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { // Tests a Pad() with a zero-element input but a non-zero-element output. XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; auto dimension = padding_config.add_dimensions(); @@ -115,7 +115,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { } XLA_TEST_P(PadTestFloat, Pad1DS3Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; auto dimension = padding_config.add_dimensions(); @@ -130,7 +130,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { } XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, @@ -138,7 +138,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { } TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ {1.0f, 2.0f}, // row 0 @@ -162,7 +162,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { } TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -181,7 +181,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { } TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); PaddingConfig padding_config; auto dimension0 = padding_config.add_dimensions(); @@ -223,7 +223,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { } XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); PaddingConfig padding_config; auto dimension0 = padding_config.add_dimensions(); @@ -266,7 +266,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { } XLA_TEST_F(PadTest, Pad4DU8Array) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ {1, 2}, // row 0 @@ -290,7 +290,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { } XLA_TEST_F(PadTest, Pad4DPredArray) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Since bool is currently not well supported, use Broadcast operation to // create the operand for Pad. @@ -317,7 +317,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { } XLA_TEST_P(PadTestFloat, Large2DPad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto ones = MakeUnique>(4, 4); ones->Fill(1.0f); @@ -329,15 +329,14 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, AllTypes2DPad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; @@ -352,15 +351,14 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } XLA_TEST_P(PadTestFloat, High2DPad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 129; constexpr int64 in_cols = 129; @@ -378,8 +376,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -387,7 +384,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { } XLA_TEST_P(PadTestFloat, NegativePadding2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 129; constexpr int64 in_cols = 129; @@ -406,8 +403,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -415,7 +411,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { } XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); constexpr int64 in_rows = 8; constexpr int64 in_cols = 11; @@ -434,8 +430,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), - padding_config); + b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -444,20 +439,19 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { // Regression test for b/31827337. XLA_TEST_P(PadTestFloat, ReducePad) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto ones = MakeUnique>(2, 2, 2, 2); ones->Fill(1.0); auto input = AddParam(*ones, &b); - Computation add = CreateScalarAddComputation(FloatType(), &b); + XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - auto padded = b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), - padding_config); + b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index dc7ce3253cee255a7949326fa5b49fc8917432b8..b311785449f1774c3bc1e4d7ad35c2866e3b4061 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" @@ -228,15 +228,14 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { // This is required for proper handling of NaN values. SetFastMathDisabled(true); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({input_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a = builder.Parameter(0, a_literal->shape(), "a"); - auto reduce_precision = - builder.ReducePrecision(a, exponent_bits, mantissa_bits); + builder.ReducePrecision(a, exponent_bits, mantissa_bits); ComputeAndCompareR1(&builder, expected_values, {a_data.get()}); } @@ -252,7 +251,7 @@ class ReducePrecisionInsertionTest : public ClientLibraryTestBase {}; // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -265,7 +264,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // Near 1.0, Log(x) approximates x - 1; this lets us confirm that the // reduce-precision operation showed up in the correct place in the // graph. - auto log = builder.Log(abs); + builder.Log(abs); // Insert precision-reduction after the Abs(x) operation, rounding that // result to exactly 1.0f. @@ -281,7 +280,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -290,7 +289,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass after operation fusion, suffixing kAbs operations. This // should not see into the fusion nodes and thus should not affect the @@ -307,7 +306,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -316,7 +315,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass after operation fusion, suffixing kFusion operations. auto reduce_precision_pass = execution_options_.mutable_debug_options() @@ -331,7 +330,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -340,7 +339,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass suffixing fusion nodes containing kCos operations. This // should have no effect. @@ -356,7 +355,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // The interpreter has no fusion pass, so skip this test. XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); std::unique_ptr a_data = @@ -365,7 +364,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, // These two operations should be fused by any reasonable backend. auto abs = builder.Abs(a); - auto neg = builder.Neg(abs); + builder.Neg(abs); // Add a pass suffixing fusion nodes containing kAbs operations. This // should see the kAbs operation within the above fusion node. diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 50d7b5074d201d2292cf90224ef4cd37efdbb8d3..423ccadb5b3b7df950824349737a833c08870d77 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -39,6 +39,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -57,6 +60,11 @@ limitations under the License. namespace xla { namespace { +using FuncGeneratorForType = Computation (*)(PrimitiveType, + ComputationBuilder*); + +using FuncGenerator = Computation (*)(ComputationBuilder*); + class ReduceTest : public ClientLibraryTestBase { protected: ReduceTest() { @@ -497,21 +505,18 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { // Test that algebraic simplifier does not incorrectly fold a transpose into a // reduction operation. XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); - ComputationDataHandle input = builder.Parameter(0, input_shape, "input"); - ComputationDataHandle zero = builder.ConstantR0(0.0); - ComputationDataHandle transpose = - builder.Transpose(input, /*permutation=*/{1, 0, 2}); - ComputationDataHandle reduce = - builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + XlaOp input = builder.Parameter(0, input_shape, "input"); + XlaOp zero = builder.ConstantR0(0.0); + XlaOp transpose = builder.Transpose(input, /*permutation=*/{1, 0, 2}); + builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, reduce, {std::move(*input_data)}, - ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -755,53 +760,57 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { } XLA_TEST_F(ReduceTest, VectorizedReduce_Add) { - RunVectorizedReduceTest(CreateScalarAddComputation, - [](float a, float b) { return a + b; }, - [](int32 a, int32 b) { - return static_cast(static_cast(a) + - static_cast(b)); - }, - [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); + RunVectorizedReduceTest( + static_cast(CreateScalarAddComputation), + [](float a, float b) { return a + b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) + + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a + b; }, 0.0, 0, 0); } XLA_TEST_F(ReduceTest, VectorizedReduce_Multiply) { - RunVectorizedReduceTest(CreateScalarMultiplyComputation, - [](float a, float b) { return a * b; }, - [](int32 a, int32 b) { - return static_cast(static_cast(a) * - static_cast(b)); - }, - [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); + RunVectorizedReduceTest( + static_cast(CreateScalarMultiplyComputation), + [](float a, float b) { return a * b; }, + [](int32 a, int32 b) { + return static_cast(static_cast(a) * + static_cast(b)); + }, + [](uint32 a, uint32 b) { return a * b; }, 1.0, 1, 1); } XLA_TEST_F(ReduceTest, VectorizedReduce_Max) { - RunVectorizedReduceTest(CreateScalarMaxComputation, - [](float a, float b) { return std::max(a, b); }, - [](int32 a, int32 b) { return std::max(a, b); }, - [](uint32 a, uint32 b) { return std::max(a, b); }, - std::numeric_limits::min(), - std::numeric_limits::min(), - std::numeric_limits::min()); + RunVectorizedReduceTest( + static_cast(CreateScalarMaxComputation), + [](float a, float b) { return std::max(a, b); }, + [](int32 a, int32 b) { return std::max(a, b); }, + [](uint32 a, uint32 b) { return std::max(a, b); }, + std::numeric_limits::min(), std::numeric_limits::min(), + std::numeric_limits::min()); } XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { - RunVectorizedReduceTest(CreateScalarMinComputation, - [](float a, float b) { return std::min(a, b); }, - [](int32 a, int32 b) { return std::min(a, b); }, - [](uint32 a, uint32 b) { return std::min(a, b); }, - std::numeric_limits::max(), - std::numeric_limits::max(), - std::numeric_limits::max()); + RunVectorizedReduceTest( + static_cast(CreateScalarMinComputation), + [](float a, float b) { return std::min(a, b); }, + [](int32 a, int32 b) { return std::min(a, b); }, + [](uint32 a, uint32 b) { return std::min(a, b); }, + std::numeric_limits::max(), std::numeric_limits::max(), + std::numeric_limits::max()); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { RunVectorizedReduceTestForType( - CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); + static_cast(CreateScalarAndComputation), + [](bool a, bool b) { return a && b; }, true); } XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { RunVectorizedReduceTestForType( - CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); + static_cast(CreateScalarOrComputation), + [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, @@ -884,5 +893,78 @@ XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) { RunR2ToR1PredTest(/*and_reduce=false*/ false, /*rows=64*/ 64); } +// Tests reductions with different initial values. There's no test macro that +// combines TYPED_TEST and TYPED_P, so we have to do it manually. +class ReduceInitializerTest : public ReduceTest { + protected: + template + void DoTest(T initializer, int num_elems) { + ComputationBuilder builder(client_, TestName()); + Computation max_fn = CreateScalarMaxComputation( + primitive_util::NativeToPrimitiveType(), &builder); + + auto init = builder.ConstantR0(initializer); + std::vector input_arr(num_elems, std::numeric_limits::lowest()); + auto input_literal = Literal::CreateR1(input_arr); + auto input_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init, + max_fn, {0}); + + ComputeAndCompareR0(&builder, initializer, {input_data.get()}); + } +}; + +XLA_TEST_F(ReduceInitializerTest, U8Small) { DoTest(42, 2); } + +XLA_TEST_F(ReduceInitializerTest, U8BigPowerOf2) { DoTest(42, 4096); } + +XLA_TEST_F(ReduceInitializerTest, U8InitializerBigNonPowerOf2) { + DoTest(42, 4095); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerZero) { + DoTest(0, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerOne) { + DoTest(1, 1024); +} + +XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) { + DoTest(1234556789123, 1024); +} + +// Test the operational semantic that the init value is passed on the lhs for +// reduces. Can be tested by performing an "identity" reduce (that simply +// returns one of the parameters). In this case, we return the rhs, which for +// a 1D array with one element, should not be the init value. +XLA_TEST_F(ReduceTest, ReduceIdentity) { + ComputationBuilder builder(client_, TestName()); + Shape single_float = ShapeUtil::MakeShape(F32, {}); + builder.Parameter(0, single_float, "lhs-unused"); + builder.Parameter(1, single_float, "rhs-used"); + auto computation_status = builder.Build(); + TF_ASSERT_OK(computation_status.status()); + + Shape operand_shape = ShapeUtil::MakeShape(F32, {1}); + builder.Reduce(builder.Parameter(0, operand_shape, "operand"), + builder.Parameter(1, single_float, "init"), + computation_status.ValueOrDie(), {0}); + + float operand[] = {42.0f}; + float init = 58.5f; + float expected = 42.0f; + std::unique_ptr input_literal = Literal::CreateR1(operand); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + std::unique_ptr input_literal2 = Literal::CreateR0(init); + std::unique_ptr input_global_data2 = + client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + ComputeAndCompareR0( + &builder, expected, {input_global_data.get(), input_global_data2.get()}, + ErrorSpec(0.0001)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index b11b64e40a582150d6adf29e915cd70b4bcb982b..0a097667222d8c315cb9943688d6dbcb4426a8b3 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -21,10 +21,11 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -63,11 +64,9 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { class ReduceWindowTest : public ::testing::WithParamInterface, public ReduceWindowTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) { - set_use_bfloat16(GetParam()); - } + ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } - void ReduceWindowAdd(const ComputationDataHandle& input, + void ReduceWindowAdd(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -78,16 +77,17 @@ class ReduceWindowTest : public ::testing::WithParamInterface, window_dimensions, window_strides, padding); } - void ReduceWindowMax(const ComputationDataHandle& input, + void ReduceWindowMax(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); - builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, - window_strides, padding); + builder_.ReduceWindow(input, init, + CreateScalarMaxComputation(FloatType(), &builder_), + window_dimensions, window_strides, padding); } - void ReduceWindowMin(const ComputationDataHandle& input, + void ReduceWindowMin(const XlaOp& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { @@ -97,7 +97,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface, window_dimensions, window_strides, padding); } - ComputationBuilder builder_; + XlaBuilder builder_; }; TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { @@ -252,6 +252,48 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { DefaultErrorSpec()); } +// Tests the super windowing logic w.r.t handling prime number of windows in a +// major dimension with reduction. +TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { + Array4D input_array(15, 15, 4, 128); + input_array.FillRandom(2.f, 4.f); + + int win_len = 3; + int win_stride = 2; + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + // Reduce only along the x and y dimensions, according to the win_len. + ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {win_len, win_len, 1, 1}, + {win_stride, win_stride, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + +TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { + Array4D input_array(19, 17, 8, 256); + input_array.FillWithMinorDimNum(); + + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); + + Padding padding = Padding::kSame; + ReduceWindowAdd(input_data_handle, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + auto result = ReferenceUtil::ReduceWindow4DAdd( + input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); +} + // Tests a reduction function that is not a simple add/min/max/etc. XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); @@ -268,7 +310,7 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto rhs = b->Parameter(1, scalar, "rhs"); b->Min(b->Add(lhs, rhs), CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); - Computation reduce_fn = b->BuildAndNoteError(); + XlaComputation reduce_fn = b->BuildAndNoteError(); builder_.ReduceWindow( input, @@ -296,7 +338,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -364,7 +406,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -386,7 +428,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -408,7 +450,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input; + XlaOp input; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "parameter", &builder_, &input); @@ -509,7 +551,7 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = builder_.Broadcast( + XlaOp input = builder_.Broadcast( CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; @@ -568,7 +610,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } void DoIt() { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -579,7 +621,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -920,7 +962,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, }; TEST_P(R3ReduceWindowTest, Add) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); @@ -931,7 +973,7 @@ TEST_P(R3ReduceWindowTest, Add) { Literal::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); auto init_value = @@ -960,45 +1002,73 @@ struct R2ReduceWindowTestData { int64 base_bounds[2]; int64 window_bounds[2]; int64 strides[2]; + int64 pad_low[2]; + int64 pad_high[2]; int64 layout[2]; - Padding padding; Reducer reducer; } kR2TestCases[] = { {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 2}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 2}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3}, - /*strides=*/{1, 1}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100}, - /*strides=*/{2, 99}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 99}, /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, +// TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a +// ptxas bug. +#ifndef XLA_TEST_BACKEND_GPU {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25}, - /*strides=*/{5, 4}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{5, 4}, /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, +#endif {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2}, - /*strides=*/{3, 3}, /*layout=*/{0, 1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{3, 3}, /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1}, + /*layout=*/{0, 1}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36}, - /*strides=*/{4, 5}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{4, 5}, /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, // Regression test for a bug that appeared in Inception (b/34784899). {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2}, - /*strides=*/{2, 2}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{2, 2}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + // Regression test for b/73903312: bf16 lacks precision to store result of + // very large windows. Testing with a reasonable window larger than 128. + {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 64}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4}, + /*strides=*/{1, 1024}, /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0}, + /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd}, }; string R2ReduceWindowTestDataToString( @@ -1008,10 +1078,11 @@ string R2ReduceWindowTestDataToString( string str = tensorflow::strings::StrCat( "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__padding_", param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", param.layout[0], "_", param.layout[1], // + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__layout_", param.layout[0], "_", param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { str = tensorflow::strings::StrCat(str, "_bfloat16"); @@ -1026,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } void DoIt() { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); @@ -1036,20 +1107,32 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); + std::vector> padding(2); + for (int i = 0; i < 2; ++i) { + padding[i] = {param.pad_low[i], param.pad_high[i]}; + } + auto computation = param.reducer == kAdd + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); auto init_value = CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); - b.ReduceWindow(/*operand=*/parameter, - /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + auto reduce_func = param.reducer == kAdd + ? +[](float a, float b) { return a + b; } + : +[](float a, float b) { return std::max(a, b); }; + auto expected = ReferenceUtil::ReduceWindow2DGeneric( + /*operand=*/input, /*init=*/kInitValue, /*reduce_func=*/reduce_func, + /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/padding); ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); @@ -1074,8 +1157,9 @@ XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, - /*strides=*/{1, 1}, /*layout=*/{1, 0}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*strides=*/{1, 1}, /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0}, + /*layout=*/{1, 0}, + /*reducer=*/Reducer::kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -1211,7 +1295,7 @@ class R1ReduceWindowTest : public ReduceWindowTestBase, }; TEST_P(R1ReduceWindowTest, DoIt) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); @@ -1220,7 +1304,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - ComputationDataHandle parameter; + XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1315,5 +1399,58 @@ ENTRY R2Window { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } +TEST_F(ReduceWindowTextTest, R2EffectiveScalar) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[1,1]{1,0} parameter(0) + negate = f32[1,1]{1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R3EffectiveScalar) { + const string& hlo_string = R"( +HloModule R3Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R3Window { + operand = f32[1,1,1]{2,1,0} parameter(0) + negate = f32[1,1,1]{2,1,0} negate(operand) + constant = f32[] constant(1) + ROOT reduce-window = f32[1,1,1]{2,1,0} reduce-window(negate, constant), window={size=1x1x1 pad=0_0x0_0x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(HloTestBase, ReduceWindowIdentity) { + const string& hlo_string = R"( +HloModule ReduceWindowIdentity +identity.pad_to_reduce_window { + param0 = f32[] parameter(0) + ROOT param1 = f32[] parameter(1) +} +ENTRY reduce-window-identity { + operand = f32[1,32,64]{2,1,0} parameter(0) + constant.4466 = f32[] constant(0) + ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index f7b04debd4f5c40a904e32c832b6fc384a03c33b..d7462d581b8596dc43b81b0162b3f5020cebb546 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -52,11 +52,11 @@ class ReshapeTest : public ::testing::WithParamInterface, // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -67,9 +67,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { } XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); @@ -80,9 +80,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { } XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); @@ -94,11 +94,11 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(1, 1); input_array.Fill(1.0f); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", &builder, ¶meter); auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -111,15 +111,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); auto a = builder.Neg(parameter); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); + builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = Literal::CreateR1({-1.0f}); ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, @@ -130,10 +129,10 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(0, 3); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -146,11 +145,11 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -163,10 +162,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input_array(3, 0); auto input_literal = Literal::CreateR2FromArray2D(input_array); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -177,9 +176,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { // Collapses a 2-dimensional row vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -190,9 +189,9 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { // Collapses a 2-dimensional column vector to 1 dimension. XLA_TEST_P(ReshapeTest, Trivial3x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); @@ -207,9 +206,9 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { // // Splits an empty vector into an empty matrix. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -221,10 +220,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { // Splits a vector into a matrix. XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, @@ -241,9 +240,9 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { // // Transposes a 2x0 array to a 0x2 array. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -255,10 +254,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { // Transposes a 2-dimensional row vector to a column vector. XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = Literal::CreateFromArray(*simple); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -272,10 +271,10 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { // Transposes a 2-dimensional array. XLA_TEST_P(ReshapeTest, TransposeAsReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -291,11 +290,11 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. // -// Transposes a 0x4 array with ComputationBuilder::Trans. +// Transposes a 0x4 array with XlaBuilder::Transpose. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -306,10 +305,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { // Transposes a 2-dimensional array with ComputationBuilder::Trans. XLA_TEST_P(ReshapeTest, Transpose4x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Transpose(parameter, {1, 0}); @@ -327,9 +326,9 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -343,9 +342,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { // does not handle zero-sized shapes correctly. Failed last on 2017-11-30 // with an incorrect result rank. XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, @@ -358,10 +357,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, @@ -378,9 +377,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { // with an incorrect result rank. // XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -393,10 +392,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = Literal::CreateFromArray(*a4x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, @@ -420,9 +419,9 @@ static Array3D ArrayForDocR3Tests() { } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, @@ -435,9 +434,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { } XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, @@ -455,9 +454,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { } XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, @@ -470,9 +469,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { } XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, @@ -490,9 +489,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { } XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, @@ -520,12 +519,12 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // // 1 2 3 4 5 6 1 2 3 4 5 6 XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = Literal::CreateFromArray(t2x2x2x3); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); @@ -539,7 +538,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { // As above, but uses reshape directly. XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; t(0, 0, 0, 1) = 1; @@ -550,7 +549,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; auto input_literal = Literal::CreateFromArray(t); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, @@ -565,7 +564,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { // Reshape various ranks to a scalar. XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector ones(rank, 1); // this is {1, ..., 1}. std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -573,7 +572,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { std::vector zeros(rank, 0); // this is {0, ..., 0}. input_literal.Set(zeros, 83.0f); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); b.Reshape(parameter, dimensions, {}); @@ -585,9 +584,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) { } XLA_TEST_P(ReshapeTest, BadDimensions) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input_literal = Literal::CreateR1({1.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); b.Reshape(parameter, {}, {}); @@ -597,9 +596,9 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { } XLA_TEST_P(ReshapeTest, BadNewSizes) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input_literal = Literal::CreateR1({1.0f, 2.0f}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, ¶meter); b.Reshape(parameter, {1}, {}); @@ -608,7 +607,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { } XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { @@ -634,7 +633,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); @@ -645,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { {222, 333, 444, 555, 666, 777, 888, 999}, }); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, @@ -663,13 +662,13 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -690,13 +689,13 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -716,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { } XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -726,7 +725,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); @@ -738,7 +737,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { } XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -748,7 +747,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); @@ -761,7 +760,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -771,7 +770,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, @@ -788,7 +787,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { } XLA_TEST_P(ReshapeTest, NoopReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -798,12 +797,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - ComputationDataHandle parameter; + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = @@ -826,12 +825,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { } XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationDataHandle parameter; + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, @@ -845,8 +844,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, @@ -879,8 +878,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, @@ -908,8 +907,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, @@ -937,8 +936,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, @@ -967,8 +966,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, @@ -996,8 +995,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle parameter; + XlaBuilder builder(TestName()); + XlaOp parameter; auto input_data = CreateParameterAndTransferLiteral( 0, *input_literal, "input", &builder, ¶meter); builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 8fc841f14087cdea02fe44cdaea521ff92122aec..6959c95502cb7af6b720592e7836c6789719a528 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -85,7 +85,7 @@ TEST_P(FloatReverseTest, Reverses) { auto r1_literal = Literal::CreateR1(input_vector); auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = AddParam(*input_literal, &builder); builder.Rev(a, spec.reversal); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 4da6ee91607941b395b00befc98a10e7c17746ed..0c88bef69dfc522fef52422b0bd3a825fa173d44 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -163,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); builder.ConvertElementType(a, F32); - int64 value = 3LL << 32; + int64 value = 3LL << 35; std::unique_ptr a_literal = Literal::CreateR0(value); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -860,6 +860,12 @@ XLA_TEST_F(ScalarComputationsTest, MinF32Below) { TestMinMax(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); } +XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { + SetFastMathDisabled(true); + TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Min); + TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Min); +} + XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { TestMinMax(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); } @@ -868,6 +874,12 @@ XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { TestMinMax(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); } +XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { + SetFastMathDisabled(true); + TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Max); + TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Max); +} + XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. ComputationBuilder b(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 9ee94b8571e5fc8789b60501462986967ce909a0..7015e5a6a31f506d30c2629d7735482cf354455a 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -50,7 +50,7 @@ class SelectAndScatterTest : public ClientLibraryTestBase, public ::testing::WithParamInterface { public: - SelectAndScatterTest() : builder_(client_, TestName()) { + SelectAndScatterTest() : builder_(TestName()) { // Create S32 GE and ADD computations for select and scatter respectively. ge_s32_ = CreateScalarGeComputation(S32, &builder_); add_s32_ = CreateScalarAddComputation(S32, &builder_); @@ -60,13 +60,13 @@ class SelectAndScatterTest min_f32_ = CreateScalarMinComputation(F32, &builder_); } - ComputationBuilder builder_; - Computation ge_s32_; - Computation add_s32_; - Computation ge_f32_; - Computation add_f32_; - Computation max_f32_; - Computation min_f32_; + XlaBuilder builder_; + XlaComputation ge_s32_; + XlaComputation add_s32_; + XlaComputation ge_f32_; + XlaComputation add_f32_; + XlaComputation max_f32_; + XlaComputation min_f32_; }; XLA_TEST_P(SelectAndScatterTest, ParamTest) { @@ -80,12 +80,11 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) { s.FillRandom(12.0f); auto source = builder_.ConstantFromArray(s); - auto select_and_scatter = builder_.SelectAndScatter( - operand, ge_f32_, GetParam().window_dimensions, GetParam().window_strides, - GetParam().padding_type, source, builder_.ConstantR0(0.0f), - add_f32_); + builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, + GetParam().window_strides, GetParam().padding_type, + source, builder_.ConstantR0(0.0f), add_f32_); - ComputeAndCompare(&builder_, select_and_scatter, {}, ErrorSpec(1e-5)); + ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5)); } INSTANTIATE_TEST_CASE_P( @@ -252,6 +251,21 @@ XLA_TEST_F(SelectAndScatterTest, R2S32) { ComputeAndCompareR2(&builder_, expected, {}); } +// Test for tie breaking rule in ge_f32_. When a tie is present, the operand +// that has the lower lexicographical order (smaller index) should be chosen. +XLA_TEST_F(SelectAndScatterTest, R2F32Tie) { + const auto operand = builder_.ConstantR2( + {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}}); + const auto source = builder_.ConstantR2( + {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}}); + Array2D expected( + {{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}}); + builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3}, + /*window_strides=*/{1, 1}, Padding::kSame, source, + builder_.ConstantR0(0.0f), add_f32_); + ComputeAndCompareR2(&builder_, expected, {}, ErrorSpec(1e-7)); +} + // Similar to SelectAndScatterTest.R2S32 but the input is transposed. XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) { const auto operand = builder_.ConstantR2( diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index ac163df127e0087c02777fa3d5ce7970c51b97b9..52195db2aa74710b901dd7744a670764a034e96b 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -41,7 +41,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1}); @@ -54,7 +54,7 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1}); @@ -67,7 +67,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { Array3D values(3, 3, 3); values.FillIota(0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR3FromArray3D(values); builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1}); @@ -77,7 +77,7 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) { } XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 0)); builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); @@ -85,7 +85,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { } XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(0, 20)); builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); @@ -93,7 +93,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { } XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(Array2D(3, 0)); builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); @@ -108,7 +108,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); @@ -126,7 +126,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { Array2D values(1, 4096); std::iota(values.data(), values.data() + 4096, 0.0); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); @@ -147,7 +147,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR2FromArray2D(values); builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); ComputeAndCompareR2(&builder, expected, {}, ErrorSpec(0.000001)); @@ -159,7 +159,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); @@ -172,7 +172,7 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { /*strides=*/{{1, 1, 2, 1}}); auto expected_literal = Literal::CreateR4FromArray4DWithLayout( *expected, LayoutUtil::MakeLayout({0, 1, 2, 3})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), @@ -193,15 +193,18 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - std::vector input(spec.input_dim0); + // This can't be an std::vector, since you can't grab an ArraySlice of a + // vector. + tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto original = builder.ConstantR1(input); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); - std::vector expected; + // Ditto. + tensorflow::gtl::InlinedVector expected; for (int i = spec.slice_start; i < spec.slice_limit; i += spec.slice_stride) { expected.push_back(i); @@ -211,6 +214,9 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +// A version of SliceR1Test used to label and disable 'large' tests +class SliceR1LargeTest : public SliceR1Test {}; + string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, @@ -230,6 +236,21 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } +XLA_TEST_P(SliceR1LargeTest, DoIt_F32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_F64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_U32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_S32) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_U64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1LargeTest, DoIt_S64) { Run(GetParam()); } + +XLA_TEST_P(SliceR1Test, DoIt_PRED) { Run(GetParam()); } + + // Tests for R1 slice ops. // The format for each testcase is {input size, start, limit, stride}. // clang-format off @@ -267,13 +288,32 @@ INSTANTIATE_TEST_CASE_P( R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, - R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1} + ), + SliceR1TestDataToString +); + // TODO(b/69425338): This uses too much memory on GPU. #ifndef XLA_TEST_BACKEND_GPU - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +INSTANTIATE_TEST_CASE_P( + SliceR1TestBigSlicesInstantiation, + SliceR1LargeTest, + ::testing::Values( + R1Spec{ + 16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{ + 16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{ + 16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1} + ), + SliceR1TestDataToString +); #endif + +INSTANTIATE_TEST_CASE_P( + SliceStridedR1TestInstantiation, + SliceR1Test, + ::testing::Values( R1Spec{10, 2, 4, 2}, R1Spec{10, 0, 10, 2}, R1Spec{10, 0, 10, 3}, @@ -285,8 +325,24 @@ INSTANTIATE_TEST_CASE_P( R1Spec{2047, 1024 - 24, 1024 + 160, 31}, R1Spec{2047, 1, 2046, 3 * 128}, R1Spec{4096, 1024 + 3, 4095, 500}, - R1Spec{8192, 0, 8192, 1024 * 3 + 400} - ), + R1Spec{8192, 0, 8192, 1024 * 3 + 400}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 2}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 8}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 7}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 125}, + R1Spec{1024 * 1024, 3, 1024 - 9, 2}, + R1Spec{1024 * 1024, 3, 1024 - 9, 8}, + R1Spec{1024 * 1024, 3, 1024 - 9, 7}, + R1Spec{1024 * 1024, 3, 1024 - 9, 125}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 2}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 8}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 7}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 125}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125} + ), SliceR1TestDataToString ); // clang-format on @@ -310,7 +366,7 @@ XLA_TEST_P(SliceR2Test, DoIt) { Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(spec.layout)); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); @@ -400,7 +456,7 @@ class SliceR4Test : public ClientLibraryTestBase, values.FillRandom(3.14f); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto literal = Literal::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); auto parameter = builder.Parameter(0, literal->shape(), "p0"); diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 978a669bcab720bddec5c4bcd0144810ba3c8477..be35ec6c6ee4c015755622b2dc9bb92e23af7c85 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index b060fb13b1451aab30cfca73bea0a4a598a9fa3a..cda1989fad670c805f30b5043e342d5f9a9a6fe2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -23,14 +23,14 @@ namespace xla { namespace { -template -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { +template +void PopulateWithRandomFloatingPointDataImpl(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. - std::uniform_real_distribution generator(1.0f, 1.125f); + std::uniform_real_distribution generator(1.0f, 1.125f); const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice indices) { @@ -52,10 +52,22 @@ void PopulateWithRandomFloatingPointData(Literal* literal, FloatT index_bias = static_cast(index_product % 113 - negative_bias) / static_cast(256.0f); - return (generator(*engine) - 1.0625) + index_bias; + return static_cast(generator(*engine) - 1.0625f) + index_bias; })); } +template +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl(literal, engine); +} + +template <> +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl(literal, engine); +} + // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> @@ -100,6 +112,9 @@ StatusOr> MakeFakeLiteralInternal( case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); break; + case F16: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; case F32: PopulateWithRandomFloatingPointData(literal.get(), engine); break; @@ -145,27 +160,38 @@ StatusOr> MakeFakeLiteralInternal( return std::move(literal); } -// Matches binary addition computations. -bool LooksLikeSum(const HloComputation& computation) { +enum class ConstantType { kUnknown, kZero, kOne }; + +// Return the constant type required by this computation, if known. +ConstantType GetInitValue(const HloComputation& computation) { const HloInstruction* const root = computation.root_instruction(); - return root->opcode() == HloOpcode::kAdd && - computation.num_parameters() == 2 && - root->operand(0)->opcode() == HloOpcode::kParameter && - root->operand(1)->opcode() == HloOpcode::kParameter && - root->operand(0) != root->operand(1); + if (computation.num_parameters() != 2 || root->operand_count() != 2 || + root->operand(0)->opcode() != HloOpcode::kParameter || + root->operand(1)->opcode() != HloOpcode::kParameter || + root->operand(0) == root->operand(1)) { + return ConstantType::kUnknown; + } + + switch (root->opcode()) { + case HloOpcode::kAdd: + return ConstantType::kZero; + case HloOpcode::kMultiply: + return ConstantType::kOne; + default: + return ConstantType::kUnknown; + } } -// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, -// which requires an init_value of 0 rather than a random value. -bool NeedsZeroInitValue(const HloUse& use) { +// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random +// initialization value. +bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; return ( ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1 && LooksLikeSum(*instruction->to_apply())) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && - LooksLikeSum(*instruction->scatter()))); + op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); } // Generate random values that are constrained to the input_shape minus the @@ -207,7 +233,7 @@ std::vector FindConstrainedUses( auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), fused_uses.end()); - } else if (NeedsZeroInitValue(use)) { + } else if (NeedsInitValue(use)) { constrained_uses.push_back(instruction); } else if (opcode == HloOpcode::kConvert || opcode == HloOpcode::kReducePrecision) { @@ -228,7 +254,8 @@ StatusOr> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { HloInstruction* needs_index = nullptr; - HloInstruction* needs_zero = nullptr; + HloInstruction* needs_constant = nullptr; + ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: @@ -243,8 +270,13 @@ StatusOr> CreateLiteralForConstrainedUses( case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + needs_constant = use; + constant_type = GetInitValue(*use->to_apply()); + break; + case HloOpcode::kSelectAndScatter: - needs_zero = use; + needs_constant = use; + constant_type = GetInitValue(*use->scatter()); break; default: @@ -253,17 +285,26 @@ StatusOr> CreateLiteralForConstrainedUses( use->ToString().c_str()); } } - if (needs_index != nullptr && needs_zero != nullptr) { + if (needs_index != nullptr && needs_constant != nullptr) { return Unimplemented( "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "zero: %s\n", - needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + "constant: %s\n", + needs_index->ToString().c_str(), needs_constant->ToString().c_str()); } if (needs_index != nullptr) { return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), needs_index->shape(), engine); - } else if (needs_zero != nullptr) { - return Literal::CreateFromShape(param.shape()); + } else if (needs_constant != nullptr) { + switch (constant_type) { + case ConstantType::kZero: + return Literal::Zero(param.shape().element_type()).CloneToUnique(); + case ConstantType::kOne: + return Literal::One(param.shape().element_type()).CloneToUnique(); + case ConstantType::kUnknown: + // We want the identity element for the computation, but we don't really + // know what it is - so any value we generate will be just as wrong. + return MakeFakeLiteralInternal(param.shape(), engine); + } } else { return MakeFakeLiteralInternal(param.shape(), engine); } @@ -287,7 +328,7 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { StatusOr>> MakeFakeArguments( HloModule* const module) { - TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); std::minstd_rand0 engine; std::vector> arguments(params.size()); @@ -299,8 +340,8 @@ StatusOr>> MakeFakeArguments( } Status VerifyHloModule(const perftools::gputools::Platform& platform, - HloModule* const module) { - return HloVerifier().Run(module).status(); + HloModule* const module, bool allow_mixed_precision) { + return HloVerifier(allow_mixed_precision).Run(module).status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2..b5ab779574fd5237d14cd24c345a9d5f1d41d1fd 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -69,7 +69,8 @@ StatusOr>> MakeFakeArguments( // Check that a given module satisfies various constraints before trying to // execute it. Status VerifyHloModule(const perftools::gputools::Platform& platform, - HloModule* const module); + HloModule* const module, + bool allow_mixed_precision = false); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e8efc6e2a83f42bf81fc1261ba508632cf3f85b3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_utils.h" + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +// A test fixture is used because we need a client for our computation builder. +class TestUtilsTest : public LocalClientTestBase {}; + +XLA_TEST_F(TestUtilsTest, UnusedParam) { + ComputationBuilder builder(local_client_, TestName()); + // Make the reduction lambda. + Shape single_float = ShapeUtil::MakeShape(F32, {}); + builder.Parameter(0, single_float, "unused"); + builder.Parameter(1, single_float, "used"); + auto computation_status = builder.Build(); + TF_ASSERT_OK(computation_status.status()); + + // Make the reduction. + Shape pair_float = ShapeUtil::MakeShape(F32, {2}); + builder.Reduce(builder.Parameter(0, pair_float, "operand"), + builder.Parameter(1, single_float, "init"), + computation_status.ValueOrDie(), {0}); + computation_status = builder.Build(); + TF_ASSERT_OK(computation_status.status()); + + auto executable_status = local_client_->Compile( + computation_status.ValueOrDie(), {&pair_float, &single_float}, + ExecutableBuildOptions()); + TF_ASSERT_OK(executable_status.status()); + HloModule& module = const_cast( + executable_status.ValueOrDie()->executable()->module()); + TF_ASSERT_OK(MakeFakeArguments(&module).status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 2029312f94a14bc81706368b9ecfc2727fd9fe4c..098be6d7aabe88d0deef600716229ddbd0bcae2f 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -20,11 +20,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,7 +43,7 @@ class TupleTest : public ClientLibraryTestBase { // Tests a tuple-shaped constant. XLA_TEST_F(TupleTest, TupleConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -53,13 +56,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { Literal::CreateR1(constant_vector).get(), Literal::CreateR2(constant_matrix).get()}); - auto result = builder.ConstantLiteral(*value); + builder.ConstantLiteral(*value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } // Tests a tuple made of scalar constants. XLA_TEST_F(TupleTest, TupleScalarConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; @@ -67,13 +70,13 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), Literal::CreateR0(constant_scalar2).get()}); - auto result = builder.ConstantLiteral(*value); + builder.ConstantLiteral(*value); ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreate) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; @@ -81,9 +84,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto result = builder.Tuple({builder.ConstantR0(constant_scalar), - builder.ConstantR1(constant_vector), - builder.ConstantR2(constant_matrix)}); + builder.Tuple({builder.ConstantR0(constant_scalar), + builder.ConstantR1(constant_vector), + builder.ConstantR2(constant_matrix)}); auto expected = Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), @@ -94,9 +97,9 @@ XLA_TEST_F(TupleTest, TupleCreate) { // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - auto result = builder.Tuple( + builder.Tuple( {builder.ConstantR0(7.0), builder.ConstantR1({})}); auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), @@ -106,15 +109,15 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { // Tests the creation of an empty tuple. XLA_TEST_F(TupleTest, EmptyTupleCreate) { - ComputationBuilder builder(client_, TestName()); - auto result = builder.Tuple({}); + XlaBuilder builder(TestName()); + builder.Tuple({}); auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElement) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -122,23 +125,23 @@ XLA_TEST_F(TupleTest, GetTupleElement) { }; auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + builder.GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(constant_matrix), {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto tuple_data = builder.Tuple( {builder.ConstantR1({}), builder.ConstantR2FromArray2D(Array2D(0, 101))}); - auto matrix_element = builder.GetTupleElement(tuple_data, 1); + builder.GetTupleElement(tuple_data, 1); ComputeAndCompareR2(&builder, Array2D(0, 101), {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto value = builder.ConstantR1({4.5f}); builder.GetTupleElement(value, 1); auto result_status = builder.Build(); @@ -151,7 +154,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) { // Extracts both elements from a tuple with GetTupleElement and then adds them // together. XLA_TEST_F(TupleTest, AddTupleElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -163,22 +166,22 @@ XLA_TEST_F(TupleTest, AddTupleElements) { auto matrix_element = builder.GetTupleElement(tuple_data, 1); auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie(); auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie(); - auto result = builder.Add(matrix_element, vector_element, - /*broadcast_dimensions=*/{1}); + builder.Add(matrix_element, vector_element, + /*broadcast_dimensions=*/{1}); Array2D expected({ {2.f, 4.f, 6.f}, // row 0 {5.f, 7.f, 9.f}, // row 1 }); - ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3})); - ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3})); + ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3})); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } // Extracts both elements from a tuple and then puts them into a new tuple in // the opposite order. XLA_TEST_F(TupleTest, TupleGTEToTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -186,8 +189,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { }; auto tuple_data = builder.Tuple({builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), - builder.GetTupleElement(tuple_data, 0)}); + builder.Tuple({builder.GetTupleElement(tuple_data, 1), + builder.GetTupleElement(tuple_data, 0)}); auto expected = Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), Literal::CreateR1(constant_vector).get()}); @@ -195,8 +198,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle v1, v2; + XlaBuilder b(TestName()); + XlaOp v1, v2; for (bool direction : {false, true}) { std::unique_ptr v1_data = @@ -209,7 +212,7 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v2_gt = b.Gt(v2, v1); // true auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} - auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); auto expected = Literal::MakeTuple({Literal::CreateR0(direction).get(), Literal::CreateR0(!direction).get()}); @@ -236,7 +239,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { // \ (tuple10)-- / // \ / \ / // -----(GTE 0)-- --(GTE 1)---------- - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list constant_vector = {1.f, 2.f, 3.f}; std::initializer_list> constant_matrix = { {1.f, 2.f, 3.f}, // row 0 @@ -256,8 +259,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { auto addvectors = builder.Add(vector_from_01, vector_from_10); auto addmatrices = builder.Add(matrix_from_01, matrix_from_10); - auto result = builder.Add(addmatrices, addvectors, - /*broadcast_dimensions=*/{1}); + builder.Add(addmatrices, addvectors, + /*broadcast_dimensions=*/{1}); Array2D expected({ {4.f, 8.f, 12.f}, // row 0 @@ -268,7 +271,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { // Tests a selection between tuples with "false" path taken. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -277,8 +280,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { auto tuple21 = builder.Tuple( {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + builder.Select(builder.ConstantR0(false), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -313,7 +315,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { // Tests a selection between tuples with "true" path taken. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -322,8 +324,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { auto tuple21 = builder.Tuple( {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); - auto select = - builder.Select(builder.ConstantR0(true), tuple12, tuple21); + builder.Select(builder.ConstantR0(true), tuple12, tuple21); auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); @@ -332,7 +333,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { // Tests a selection between tuples but the final result is an element of the // tuple, not the whole tuple. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -343,7 +344,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto element = builder.GetTupleElement(select, 0); + builder.GetTupleElement(select, 0); ComputeAndCompareR1(&builder, vec2, {}, error_spec_); } @@ -367,7 +368,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { // / --(GTE 1)-- // / // (tuple 21) - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -383,8 +384,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21); auto select2 = builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1); - auto result = builder.Add(builder.GetTupleElement(select2, 0), - builder.GetTupleElement(select2, 1)); + builder.Add(builder.GetTupleElement(select2, 0), + builder.GetTupleElement(select2, 1)); ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } @@ -393,7 +394,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) { // Similar to SelectBetweenTuples, but the constants are shared between the // input tuples. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -402,19 +403,18 @@ XLA_TEST_F(TupleTest, auto tuple12 = builder.Tuple({c1, c2}); auto tuple21 = builder.Tuple({c2, c1}); - auto select = - builder.Select(builder.ConstantR0(false), tuple12, tuple21); + builder.Select(builder.ConstantR0(false), tuple12, tuple21); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto inner_tuple = builder.Tuple( {builder.ConstantR1({1.0, 2.0}), builder.ConstantR0(42.0)}); - auto outer_tuple = - builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); + builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); auto expected_v1 = Literal::CreateR1({1.0, 2.0}); auto expected_s = Literal::CreateR0(42.0); @@ -428,7 +428,7 @@ XLA_TEST_F(TupleTest, NestedTuples) { } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {3}); Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape}); @@ -459,7 +459,7 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { } XLA_TEST_F(TupleTest, ComplexTuples) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); { Shape c64r0 = ShapeUtil::MakeShape(C64, {}); Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); @@ -514,5 +514,33 @@ XLA_TEST_F(TupleTest, ComplexTuples) { error_spec_); } +class TupleHloTest : public HloTestBase {}; + +// Disabled on CPU parallel because that's broken and will be removed soon. +// Disabled on the interpreter because bitcast doesn't exist on the interpreter. +TEST_F(TupleHloTest, + DISABLED_ON_INTERPRETER(DISABLED_ON_CPU_PARALLEL(BitcastAfterGTE))) { + const char* testcase = R"( + HloModule m + + ENTRY test { + name.1 = (f32[3]{0}) parameter(0) + get-tuple-element.1 = f32[3]{0} get-tuple-element(name.1), index=0 + bitcast = f32[1,3]{1,0} bitcast(get-tuple-element.1) + copy = f32[1,3]{1,0} copy(bitcast) + ROOT tuple.4 = (f32[1,3]{1,0}) tuple(copy) + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); + TF_ASSERT_OK_AND_ASSIGN(auto result, + ExecuteNoHloPasses(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 52157b837c383205f77a030ef98b2fd03a41aff5..89ce2ce797f979b8668fbdb172a4a3abc5922b9f 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -54,29 +54,28 @@ TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -91,29 +90,28 @@ TEST_F(WhileTest, WhileWithScalarS64Result) { auto result_shape = ShapeUtil::MakeShape(S64, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -123,31 +121,30 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto orig_shape = ShapeUtil::MakeShape(S32, {2}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Gt(builder.ConstantR0(5), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Reduce(builder.ConstantR1(2, 1), builder.ConstantR0(0), CreateScalarAddComputation(S32, &builder), {0}); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -156,28 +153,28 @@ TEST_F(WhileTest, WhileWithPredicateResult) { auto result_shape = ShapeUtil::MakeShape(PRED, {}); // Create a computation for the condition: run until condition is true. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Ne(builder.ConstantR0(true), prev); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body: or condition with true. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.Or(prev, builder.ConstantR0(true)); + builder.Or(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true)); - auto result = builder.While(condition, body, init); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, true, {}); } @@ -194,9 +191,9 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {0}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -205,33 +202,34 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 15.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1({}); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1({}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.0001)); } @@ -247,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -258,33 +256,34 @@ TEST_F(WhileTest, WhileWithVectorResult) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); // Individual elements with increase by 1/8 each time through the loop, so // the sum will increase by 1.0. It will first be >15.5 when the elements @@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { Shape result_shape = ShapeUtil::MakeShape(F32, {8}); // Create a computation for the reduction. - Computation add; + XlaComputation add; { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -317,34 +316,34 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // Create a computation for the condition. // Repeat until the sum of the result vector is less than 5.5f. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto sum = builder.Reduce(prev, builder.ConstantR0(0.0f), add, /*dimensions_to_reduce=*/{0}); - auto test = builder.Gt(builder.ConstantR0(15.5f), sum); + builder.Gt(builder.ConstantR0(15.5f), sum); condition = builder.Build().ConsumeValueOrDie(); } // Create a computation for the body. // Add a constant vector of 1.f to the result vector. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR1(8, 0.125f); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.ConstantR1(8, 0.f); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); builder.Tuple({result}); // Individual elements with increase by 1/8 each time through the loop, so @@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(N), iteration); @@ -377,28 +376,28 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(N); auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); @@ -419,9 +418,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the condition. // Repeat for N iterations. const int N = 2; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(N), iteration); @@ -430,21 +429,21 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { // Create a computation for the body. // Add 1 to the iteration variable permute the weights. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto w1 = builder.GetTupleElement(prev, 1); auto w2 = builder.GetTupleElement(prev, 2); auto w3 = builder.GetTupleElement(prev, 3); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); @@ -455,7 +454,7 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); std::vector expected = {6.f, 6.f, 6.f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } @@ -474,9 +473,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -486,26 +485,27 @@ TEST_F(WhileTest, WhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); - VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( @@ -523,9 +523,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -534,27 +534,27 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and or the predicate with true - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); auto new_pred = builder.Or(pred, builder.ConstantR0(true)); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple({builder.ConstantR0(0), builder.Ne(builder.ConstantR0(false), builder.ConstantR0(true))}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_predicate = Literal::CreateR0(true); @@ -570,9 +570,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -582,25 +582,24 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { // Create a computation for the body. // Add 1 to the iteration variable and set the other tuple element to a // constant. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); - auto result = - builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), - builder.ConstantR0(7)}); + builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), + builder.ConstantR0(7)}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR0(7)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR0(7); @@ -631,20 +630,20 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -654,34 +653,34 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - Computation body2; + XlaComputation body2; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -692,11 +691,11 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -710,20 +709,20 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -733,21 +732,21 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -758,11 +757,11 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -777,20 +776,20 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; const int c1 = 5; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c1)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation condition2; + XlaComputation condition2; const int c2 = 7; { - ComputationBuilder builder(client_, "condition2"); + XlaBuilder builder("condition2"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(c2)); @@ -800,21 +799,21 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto weights = builder.GetTupleElement(prev, 1); auto input = builder.ConstantR1(10, 1.f); auto new_weights = builder.Add(weights, input); - auto result = builder.Tuple( + builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_weights}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto while1 = builder.While(condition, body, init); @@ -824,11 +823,11 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { auto while_result2 = builder.GetTupleElement(while2, 1); VLOG(2) << "while_result2 = " << ShapeUtil::HumanString( - *builder.GetShape(while_result2).ConsumeValueOrDie()); + builder.GetShape(while_result2).ConsumeValueOrDie()); auto result = builder.Add(while_result1, while_result2); VLOG(2) << "result = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); const float sum = c1 + c2; std::vector expected(10, sum); ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); @@ -844,9 +843,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the condition. // Repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Gt(builder.ConstantR0(5), iteration); @@ -856,9 +855,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Create a computation for the body. // Add 1 to the iteration variable and add a constant vector of 1.0f to // the weight variable, both of which are tuple elements. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -873,18 +872,18 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, "while"); + XlaBuilder builder("while"); auto init = builder.Tuple( {builder.ConstantR0(0), builder.ConstantR1(10, 0.f)}); auto result = builder.While(condition, body, init); VLOG(2) << "while = " << ShapeUtil::HumanString( - *builder.GetShape(result).ConsumeValueOrDie()); + builder.GetShape(result).ConsumeValueOrDie()); auto expected_counter = Literal::CreateR0(5); auto expected_data = Literal::CreateR1( @@ -910,23 +909,23 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { // Per backend the values generated can be different as the different backends // use different random number generators. // TODO(b/32240857): Extend test to verify outputs. -TEST_F(WhileTest, WhileWithPrngScalarResult) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { auto v6s32 = ShapeUtil::MakeShape(S32, {6}); // Create a computation for the condition: repeat for count iterations. auto build_condition = [this, v6s32](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prev = builder.Reshape( builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0}, - {}); + {}); builder.Gt(builder.ConstantR0(count), prev); return builder.Build().ConsumeValueOrDie(); }; // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, v6s32, "prev"); auto inc = builder.ConcatInDim( {builder.ConstantR1({1}), @@ -934,16 +933,15 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { builder.ConstantR0(100), ShapeUtil::MakeShape(S32, {5}))}, 0); - auto result = builder.Add(inc, prev); + builder.Add(inc, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. auto while_loop = [this, &body, build_condition](int count) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR1({0, 0, 0, 0, 0, 0}); - auto result = builder.While(build_condition(count), body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(build_condition(count), body, init); return builder.Build(); }; @@ -1107,9 +1105,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { auto inner_result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - Computation inner_condition; + XlaComputation inner_condition; { - ComputationBuilder builder(client_, "inner_condition"); + XlaBuilder builder("inner_condition"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); builder.Lt(i, builder.ConstantR0(7)); @@ -1118,9 +1116,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the outer loop condition: // repeat while result < 30. - Computation outer_condition; + XlaComputation outer_condition; { - ComputationBuilder builder(client_, "outer_condition"); + XlaBuilder builder("outer_condition"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); builder.Lt(prev, builder.ConstantR0(30)); outer_condition = builder.Build().ConsumeValueOrDie(); @@ -1128,34 +1126,33 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to // `result`. - Computation inner_body; + XlaComputation inner_body; { - ComputationBuilder builder(client_, "inner_body"); + XlaBuilder builder("inner_body"); auto params = builder.Parameter(0, inner_result_shape, "prev"); auto i = builder.GetTupleElement(params, 0); auto result = builder.GetTupleElement(params, 1); i = builder.Add(builder.ConstantR0(1), i); result = builder.Add(builder.ConstantR0(2), result); - auto output = builder.Tuple({i, result}); + builder.Tuple({i, result}); inner_body = builder.Build().ConsumeValueOrDie(); } // Creates a computation for the outer loop: run the inner loop with i = 0. - Computation outer_body; + XlaComputation outer_body; { - ComputationBuilder builder(client_, "outer_body"); + XlaBuilder builder("outer_body"); auto prev = builder.Parameter(0, outer_result_shape, "prev"); auto init = builder.Tuple({builder.ConstantR0(0), prev}); auto result = builder.While(inner_condition, inner_body, init); - auto output = builder.GetTupleElement(result, 1); + builder.GetTupleElement(result, 1); outer_body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(outer_condition, outer_body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(outer_condition, outer_body, init); ComputeAndCompareR0(&builder, 42, {}); } @@ -1166,22 +1163,22 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) { // while (f(result).get<0>()) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithCallInsideCondition) { +TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition_callee; + XlaComputation condition_callee; { - ComputationBuilder builder(client_, "condition_callee"); + XlaBuilder builder("condition_callee"); auto prev = builder.Parameter(0, result_shape, "prev"); builder.Tuple({builder.Gt(builder.ConstantR0(5), prev)}); condition_callee = builder.Build().ConsumeValueOrDie(); } - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, result_shape, "prev"); auto result = builder.Call(condition_callee, {prev}); builder.GetTupleElement(result, 0); @@ -1189,20 +1186,19 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) { } // Create a computation for the body: add 1 to the result variable. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, result_shape, "prev"); auto input = builder.ConstantR0(1); - auto result = builder.Add(input, prev); + builder.Add(input, prev); body = builder.Build().ConsumeValueOrDie(); } // Create a While node with computations for the condition and the body. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto init = builder.ConstantR0(0); - auto result = builder.While(condition, body, init); - auto shape = builder.GetShape(result).ConsumeValueOrDie(); + builder.While(condition, body, init); ComputeAndCompareR0(&builder, 5, {}); } @@ -1214,28 +1210,28 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); // Create a computation for the condition: repeat for 5 iterations. - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_shape, "state"); builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_shape, "state"); auto indvar = builder.GetTupleElement(state, 0); auto input_0 = builder.GetTupleElement(state, 1); auto input_1 = builder.GetTupleElement(state, 2); auto output = builder.Tanh(builder.Dot(input_0, input_1)); auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); - auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + builder.Tuple({indvar_next, input_0, input_1, output}); TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); auto init = builder.Tuple( {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); @@ -1268,9 +1264,9 @@ void BM_WhileLoop(int num_iters) { // Create while condition computation with 'loop_limit'. const int32 loop_limit = 100; - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); builder.Lt(iteration, builder.ConstantR0(loop_limit)); @@ -1278,9 +1274,9 @@ void BM_WhileLoop(int num_iters) { } // Create while body computation with unit loop increment. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto prev = builder.Parameter(0, loop_state_shape, "prev"); // TupleElement 0 auto iteration = builder.GetTupleElement(prev, 0); @@ -1294,12 +1290,12 @@ void BM_WhileLoop(int num_iters) { auto starts = builder.ConstantR1({0, 0, 0}); // UpdateSlice. auto out1 = builder.DynamicUpdateSlice(input, update, starts); - auto result = builder.Tuple({out0, out1}); + builder.Tuple({out0, out1}); body = builder.Build().ConsumeValueOrDie(); } // Create a While instruction. - ComputationBuilder builder(client, "while"); + XlaBuilder builder("while"); auto zero = builder.ConstantR0(0.0); auto input = builder.Broadcast(zero, {seq_len, 1024, 1024}); auto init = builder.Tuple({builder.ConstantR0(0), input}); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 9ad2a1985331b80625dd0687ea052300bc99e440..ff3418a128eed82b730a6602d6e3faba4ad7be32 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -144,7 +145,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, - ExecutableBuildOptions())); + ExecutableBuildOptions().set_hlo_profile(true))); Executable* executable = local_executable->executable(); HloExecutionProfile hlo_execution_profile( @@ -294,7 +295,8 @@ XLA_TEST_F(HloProfileTest, auto while_body_profile_start = std::find_if(profile_output_lines.begin(), profile_output_lines.end(), [](tensorflow::StringPiece s) { - return s.starts_with("Execution profile for body"); + return tensorflow::str_util::StartsWith( + s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, profile_output_lines.end()); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index 92b2b1ee778f8b0f8104e7d7ff27a5c11db59768..a9f2915b458b1816926de727b3da21982d06f6c0 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector flag_list; @@ -25,7 +29,38 @@ GTEST_API_ int main(int argc, char** argv) { return 2; } + // If the --benchmarks flag is passed in then only run the benchmarks, not the + // tests. + for (int i = 1; i < argc; i++) { + tensorflow::StringPiece arg(argv[i]); + if (arg == "--benchmarks" || + tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + const char* pattern = nullptr; + if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + pattern = argv[i] + strlen("--benchmarks="); + } else { + // Handle flag of the form '--benchmarks foo' (no '='). + if (i + 1 >= argc || + tensorflow::str_util::StartsWith(argv[i + 1], "--")) { + LOG(ERROR) << "--benchmarks flag requires an argument."; + return 2; + } + pattern = argv[i + 1]; + } + // Unfortunately Google's internal benchmark infrastructure has a + // different API than Tensorflow's. +#if defined(PLATFORM_GOOGLE) + base::SetFlag(&FLAGS_benchmarks, pattern); + RunSpecifiedBenchmarks(); +#else + tensorflow::testing::Benchmark::Run(pattern); +#endif + return 0; + } + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 6fa4c48e11d1102367b21bc21d4734466495ef0e..44f874cd2ae8e6f65dc282b8675f195ec9c09415 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -38,7 +38,7 @@ namespace xla { StatusOr> TextLiteralReader::ReadPath( tensorflow::StringPiece path) { - CHECK(!path.ends_with(".gz")) + CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = @@ -115,7 +115,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::StringPiece value_string = pieces[1]; tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!coordinates_string.Consume("(")) { + if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); } diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 091fa0c3ec807a66449eca0bfbb141285b8eb532..0bc4045a5490319994b6cf24daf99fe856167507 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -75,6 +75,7 @@ cc_library( name = "replay_computation_library", srcs = ["replay_computation.cc"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -222,17 +223,3 @@ tf_cc_binary( "//tensorflow/core:lib", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index 97aacf6b39f83978e732060817cd93ede81ca782..0fa4b98d0a41a1e7c681bb2302da3b752315867b 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -70,17 +70,3 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) - -# ----------------------------------------------------------------------------- - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 89def5d5610cb9522a69297668b443b8c4e03fb5..b2f122982adf750106f034e7e786367720ebafcf 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -724,6 +724,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], *broadcast_dimensions)); break; } + case HloOpcode::kBroadcastDimOne: { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateBroadcastDimOne(shape, operands[0])); + break; + } case HloOpcode::kConcatenate: { optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, @@ -994,6 +1003,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands, *custom_call_target)); break; } + case HloOpcode::kHostCompute: { + optional channel_name; + optional cost_estimate_ns; + attrs["channel_name"] = {/*required=*/true, AttrTy::kString, + &channel_name}; + attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, + &cost_estimate_ns}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( + shape, operands, *channel_name, *cost_estimate_ns)); + break; + } case HloOpcode::kDot: { optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { @@ -1035,6 +1058,40 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } + case HloOpcode::kGather: { + optional> output_window_dims; + attrs["output_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; + optional> elided_window_dims; + attrs["elided_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; + optional> gather_dims_to_operand_dims; + attrs["gather_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &gather_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + optional> window_bounds; + attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, + &window_bounds}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateGather( + shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], + dim_numbers, *window_bounds)); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index b8c6b59204f897c7dc07b846370b5b776a19a808..57684b58346166f7e3ef9576f6cd8f70ab9dc389 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -56,6 +57,18 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } +)" +}, +// broadcast size-one dimensions +{ +"BroadcastDimOne", +R"(HloModule broadcast_dim_one_module + +ENTRY %broadcast-dim-one () -> f32[2,2] { + %constant = f32[1,2]{1,0} constant(f32[1,2] { { 1.1, 2.2 } }) + ROOT %broadcast-dim-one = f32[2,2]{1,0} broadcast-dim-one(f32[1,2]{1,0} %constant) +} + )" }, // pred constant @@ -716,6 +729,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] { ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) } +)" +}, +{ +"gather", +R"(HloModule StringifyGather + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); @@ -860,6 +885,18 @@ ENTRY dot { ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} } +)" +}, +{ +"gather", +R"(HloModule gather + +ENTRY Gather { + input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); @@ -870,7 +907,7 @@ class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(StringPiece(s).contains(expected)) + EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index eda5effbb92db92c9317a956497a00c0ec15c27c..62a353ad09af009e4abf47664a5c5f7bd70a049e 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -66,6 +67,7 @@ struct Options { bool use_fake_data = false; bool print_result = true; int num_runs = 1; + bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) @@ -122,16 +124,21 @@ StatusOr> ReplayComputation( std::unique_ptr result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; + ExecutionOptions execution_options = CreateDefaultExecutionOptions(); + if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { + execution_options.mutable_debug_options()->set_xla_hlo_profile(true); + } + if (opts.print_result) { - TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( - computation, execute_arguments, - /*execution_options=*/nullptr, &profile)); + TF_ASSIGN_OR_RETURN( + result, client->ExecuteAndTransfer(computation, execute_arguments, + &execution_options, &profile)); } else { // If we're not printing the result, execute the computation but don't // bother retrieving the result. This can be a significant speedup. TF_RETURN_IF_ERROR(client ->Execute(computation, execute_arguments, - /*execution_options=*/nullptr, &profile) + &execution_options, &profile) .status()); } LOG(INFO) << "Execution took " @@ -191,6 +198,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), + tensorflow::Flag( + "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, + "Pass --xla_hlo_profile the last time we run the computation."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1f0c626bbb2d64ef4e67c9ec51485ae96ae73d04..e43498e381b8e63543e2ddda08ca7c0df91817e4 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" -#include #include #include @@ -244,8 +243,8 @@ string HumanReadableNumOps(double flops, double nanoseconds, static_cast(nano_flops * 1e9)); tensorflow::StringPiece sp(throughput); // Use the more common "G(FLOPS)", rather than "B(FLOPS)" - if (sp.ends_with("B") || // Ends in 'B', ignoring case - sp.ends_with("b")) { + if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case + tensorflow::str_util::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); @@ -292,7 +291,8 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, } int64 Product(tensorflow::gtl::ArraySlice xs) { - return std::accumulate(xs.begin(), xs.end(), 1, std::multiplies()); + return std::accumulate(xs.begin(), xs.end(), static_cast(1), + std::multiplies()); } std::vector> CommonFactors( diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 08df5b12b3a53a138f56705531baa3333b23c5d8..2da9f9ed6f40fcf5b2512f974519df0b355da10f 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include #include "tensorflow/compiler/xla/status.h" @@ -427,32 +428,106 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name); template -bool c_all_of(Container container, Predicate predicate) { - return std::all_of(std::begin(container), std::end(container), predicate); +bool c_all_of(const Container& container, Predicate&& predicate) { + return std::all_of(std::begin(container), std::end(container), + std::forward(predicate)); +} + +template +bool c_any_of(const Container& container, Predicate&& predicate) { + return std::any_of(std::begin(container), std::end(container), + std::forward(predicate)); } template -OutputIterator c_transform(InputContainer input_container, +OutputIterator c_transform(const InputContainer& input_container, OutputIterator output_iterator, - UnaryOperation unary_op) { + UnaryOperation&& unary_op) { return std::transform(std::begin(input_container), std::end(input_container), - output_iterator, unary_op); + output_iterator, + std::forward(unary_op)); } template -OutputIterator c_copy_if(InputContainer input_container, +OutputIterator c_copy_if(const InputContainer& input_container, OutputIterator output_iterator, - UnaryPredicate predicate) { + UnaryPredicate&& predicate) { return std::copy_if(std::begin(input_container), std::end(input_container), - output_iterator, predicate); + output_iterator, std::forward(predicate)); +} + +template +OutputIterator c_copy(const InputContainer& input_container, + OutputIterator output_iterator) { + return std::copy(std::begin(input_container), std::end(input_container), + output_iterator); +} + +template +void c_sort(InputContainer& input_container) { + std::sort(std::begin(input_container), std::end(input_container)); } template -void c_sort(InputContainer& input_container, Comparator comparator) { - std::sort(input_container.begin(), input_container.end(), comparator); +void c_sort(InputContainer& input_container, Comparator&& comparator) { + std::sort(std::begin(input_container), std::end(input_container), + std::forward(comparator)); +} + +template +bool c_binary_search(const Sequence& sequence, T&& value) { + return std::binary_search(std::begin(sequence), std::end(sequence), + std::forward(value)); +} + +template +bool c_is_sorted(const C& c) { + return std::is_sorted(std::begin(c), std::end(c)); +} + +template +auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { + return std::adjacent_find(std::begin(c), std::end(c)); } +template +auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { + return std::find_if(std::begin(c), std::end(c), std::forward(pred)); +} + +template +auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) { + return std::find(std::begin(c), std::end(c), std::forward(value)); +} + +template +void c_reverse(Sequence& sequence) { + std::reverse(std::begin(sequence), std::end(sequence)); +} + +template +typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, + BinaryOp&& binary_op) { + return std::accumulate(std::begin(sequence), std::end(sequence), + std::forward(init), + std::forward(binary_op)); +} + +template +int64 FindIndex(const C& c, Value&& value) { + auto it = c_find(c, std::forward(value)); + return std::distance(c.begin(), it); +} + +// Returns true if `x` fits in 32-bits. +template +bool IsInt32(T x) { + // Following conversion rules: "the value is unchanged if it can be + // represented in the destination type (and bit-field width); otherwise, the + // value is implementation-defined." + return static_cast(x) == x; +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 6b136d333bbf079efd314833f46fe3b98743fbac..1439f1bcc5cec39203a7cb4b1f8604e7349382c6 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -6,7 +6,9 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") # xla_proto_library() is a convenience wrapper around cc_proto_library. -def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): +def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs): + if kwargs.get('use_grpc_plugin'): + kwargs['use_grpc_namespace'] = True cc_proto_library(name=name, srcs=srcs, deps=deps, @@ -16,6 +18,13 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): ), protoc="@protobuf_archive//:protoc", testonly=testonly, - visibility=visibility,) + visibility=visibility, + **kwargs) + +def xla_py_grpc_library(**kwargs): + # Note: we don't currently define any special targets for Python GRPC in OSS. + _ignore = kwargs + pass + ORC_JIT_MEMORY_MAPPER_TARGETS = [] diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 56162ab44e2e0e3e4478fe631888f243332dc1d8..f619b8dc24038af64a27fc0565c74447ca9d09cf 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -16,6 +16,7 @@ limitations under the License. syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -188,6 +189,12 @@ message DebugOptions { // directory. string xla_dump_per_pass_hlo_proto_to = 96; + // Generate calls to MKL-DNN in the CPU backend. + bool xla_cpu_use_mkl_dnn = 97; + + // Maximum kernel unroll factor for the GPU backend. + int32 xla_gpu_max_kernel_unroll_factor = 98; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; @@ -298,6 +305,11 @@ message ComputationStatsRequest { DebugOptions debug_options = 2; } +message ComputationGraphStatsRequest { + HloModuleProto computation = 1; + DebugOptions debug_options = 2; +} + message ComputationStatsResponse { ComputationStats stats = 1; } @@ -342,10 +354,22 @@ message ExecuteRequest { ExecutionOptions execution_options = 5; } +message ExecuteGraphRequest { + HloModuleProto computation = 1; + repeated GlobalDataHandle arguments = 2; + + // Options that affect how XLA compiles and runs code to service this request. + ExecutionOptions execution_options = 3; +} + message ExecuteParallelRequest { repeated ExecuteRequest requests = 1; } +message ExecuteGraphParallelRequest { + repeated ExecuteGraphRequest requests = 1; +} + message ExecuteResponse { GlobalDataHandle output = 1; ExecutionProfile profile = 2; @@ -396,6 +420,11 @@ message ComputeConstantRequest { repeated LiteralProto parameters = 4; } +message ComputeConstantGraphRequest { + HloModuleProto computation = 1; + Layout output_layout = 2; +} + message ComputeConstantResponse { // A LiteralProto is returned directly for this request, instead of a // ComputationDataHandle. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 3aea0217539b89b5d60ecfaf2605eee4b69af728..f18d53c6089e8d4411099be8fb0fb8c349ace4f7 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -355,17 +355,19 @@ message WindowDimension { // positions of the window in this dimension. int64 stride = 2; - // If positive, means the amount of padding with zeroes to add to the base - // area at the low end of this dimension; if negative, its negative means the - // number of elements removed from the low end of this dimension. For example, - // in the horizontal dimension of a rectangle, this would be the number of - // zeroes to pad on the left, given that indices increase when going right. + // If positive, means the amount of padding to add to the base area at the low + // end of this dimension; if negative, its negative means the number of + // elements removed from the low end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of padding + // values to pad on the left, given that indices increase when going right. + // The actual padding value depends upon the context. Convolution pads with + // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + // init value. int64 padding_low = 3; - // As padding_low, but on the high end of this dimension. For - // example, in the horizontal dimension of a rectangle, this would - // be the number of zeroes to pad on the right, given that indices - // increase when going right. + // As padding_low, but on the high end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of values to + // pad on the right, given that indices increase when going right. int64 padding_high = 4; // Dilation factor of the sliding window in this dimension. A dilation factor @@ -393,6 +395,37 @@ message Window { repeated WindowDimension dimensions = 1; } +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. +message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the gather_indices tensor. + // + // The window indices for a specific output index Out is computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in elided_window_dims + // then 0 + // else Out[output_window_dims[i++]] + repeated int64 output_window_dims = 1; + repeated int64 elided_window_dims = 2; + + // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It + // transforms the gather index looked up from the gather_indices tensor into + // the starting index in the input space. + repeated int64 gather_dims_to_operand_dims = 3; + + // The dimension in the gather_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; +} + // Operation requests that are all collected as a tagged union with a oneof // field in OpRequest. @@ -519,6 +552,20 @@ message CustomCallRequest { Shape shape = 4; } +message HostComputeRequest { + // Operand to the HostCompute. Supports tuple. + repeated ComputationDataHandle operands = 1; + + // Name used to identify HostSend/Recv channels. + string channel_name = 2; + + // Cost estimate in nanoseconds. + int64 cost_estimate_ns = 3; + + // The shape of any data returned by host. + Shape shape = 4; +} + message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -880,6 +927,13 @@ message RecvRequest { ChannelHandle channel_handle = 2; } +message GatherRequest { + ComputationDataHandle input = 1; + ComputationDataHandle gather_indices = 2; + GatherDimensionNumbers dimension_numbers = 3; + repeated int64 window_bounds = 4; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -957,7 +1011,9 @@ message OpRequest { FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; ConditionalRequest conditional_request = 44; - // Next: 45 + HostComputeRequest host_compute_request = 45; + GatherRequest gather_request = 46; + // Next: 47 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index bab37e8906e5c648acdc1556da7e5f4601776ff5..7e47516550068e8f38d2155e48e229c2ab77b488 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -8,6 +8,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("//tensorflow:tensorflow.bzl", "if_not_windows") py_library( name = "contrib_py", @@ -33,13 +34,13 @@ py_library( "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", - "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/fused_conv:fused_conv_py", "//tensorflow/contrib/gan", @@ -51,7 +52,6 @@ py_library( "//tensorflow/contrib/image:single_image_random_dot_stereograms_py", "//tensorflow/contrib/input_pipeline:input_pipeline_py", "//tensorflow/contrib/integrate:integrate_py", - "//tensorflow/contrib/kafka", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", "//tensorflow/contrib/kfac", @@ -63,7 +63,6 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", - "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", @@ -75,16 +74,20 @@ py_library( "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", + "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/py2tf", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", + "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/contrib/rpc", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/contrib/seq2seq:seq2seq_py", "//tensorflow/contrib/signal:signal_py", @@ -110,6 +113,15 @@ py_library( "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka", + ], + "//conditions:default": [], + }) + if_not_windows([ + "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", + "//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code ]), ) @@ -119,7 +131,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", @@ -133,7 +144,13 @@ cc_library( "//tensorflow/contrib/text:all_kernels", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", - ]), + ]) + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_kernels", + ], + "//conditions:default": [], + }), ) cc_library( @@ -142,12 +159,10 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", - "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", - "//tensorflow/contrib/kafka:kafka_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", @@ -158,17 +173,11 @@ cc_library( "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", + ] + select({ + "//tensorflow:with_kafka_support_windows_override": [], + "//tensorflow:with_kafka_support": [ + "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], - ), - visibility = ["//tensorflow:__subpackages__"], + "//conditions:default": [], + }), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 4f6f539027b040de7554d09fe9118ff97aa006f8..36cc5144d072893a950c1701fc46fb55510d1fd2 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable=g-import-not-at-top # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import batching from tensorflow.contrib import bayesflow @@ -30,6 +33,7 @@ from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn from tensorflow.contrib import data from tensorflow.contrib import deprecated +from tensorflow.contrib import distribute from tensorflow.contrib import distributions from tensorflow.contrib import estimator from tensorflow.contrib import factorization @@ -60,11 +64,14 @@ from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor +from tensorflow.contrib import proto from tensorflow.contrib import quantization from tensorflow.contrib import quantize +from tensorflow.contrib import recurrent from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn +from tensorflow.contrib import rpc from tensorflow.contrib import saved_model from tensorflow.contrib import seq2seq from tensorflow.contrib import signal @@ -83,7 +90,9 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager -from tensorflow.contrib.lite.python import lite +if os.name != "nt": + from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs @@ -92,6 +101,7 @@ from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", globals(), "tensorflow.contrib.ffmpeg") +del os del LazyLoader del absolute_import diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 8dff93b4f825277dcf0a64aa3b96bd809d36e1e9..62d1b1cf079d04d50e4899cfd9ba1d405ee1efb9 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -45,16 +45,3 @@ tf_py_test( "//tensorflow/python:state_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 6658f0d9c13f6db17b25354cde2593d57f104f17..8add2aacff1d64f1617cd24167c4c6c6706044da 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -38,16 +38,15 @@ def _flatten_tensors(tensors): shape: the original shape of each element of input tensors Raises: - ValueError: tensors are empty or non-isomorphic. + ValueError: tensors are empty or non-isomorphic or have unknown shape. """ if not tensors: raise ValueError("tensors cannot be empty") shape = tensors[0].shape for tensor in tensors: shape = shape.merge_with(tensor.shape) - if shape.ndims is None: - raise ValueError("At least one of the tensors in 'tensors' must have " - "statically known rank.") + if not shape.is_fully_defined(): + raise ValueError("Tensors must have statically known shape.") if len(shape) != 1: reshaped = [] for t in tensors: diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index 47bab0a3670a90644972b2c961954a3036b8ecba..b3f5d92259df8475b205110dd3f0cee1cb5bde6f 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -36,6 +36,12 @@ from tensorflow.python.platform import tf_logging class AllReduceTest(test_util.TensorFlowTestCase): + def testFlattenTensorsShapesDefined(self): + x = array_ops.placeholder(types_pb2.DT_FLOAT, [None]) + with self.assertRaisesRegexp(ValueError, + "must have statically known shape"): + ar._flatten_tensors([x, x]) + def testRingPermutations(self): # 0 devices pred_by_c_d, rank_by_c_d = ar._ring_permutations(1, 0, []) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 4bff3c27d22c4550747a651a59909bdef80e8285..60306ebdc6cddb04e8807bfd495fa92a56e55ecd 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -38,20 +38,6 @@ cc_library( alwayslink = 1, ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # JAR with Java bindings to TF. android_library( name = "android_tensorflow_inference_java", diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 380a652435ad089f46f3ca80e4fd43097fd96e10..513d519eabbd54f46fde9ec0f004247c02277732 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system_helper.h" namespace tensorflow { namespace { @@ -228,9 +229,8 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { } string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { - string output(name); - StringPiece piece(output); - piece.Consume(prefix_); + StringPiece piece(name); + str_util::ConsumePrefix(&piece, prefix_); return piece.ToString(); } @@ -243,6 +243,11 @@ bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) { return AAssetDir_getNextFileName(dir.get()) != NULL; } +Status AssetManagerFileSystem::GetMatchingPaths(const string& pattern, + std::vector* results) { + return internal::GetMatchingPaths(this, Env::Default(), pattern, results); +} + Status AssetManagerFileSystem::NewWritableFile( const string& fname, std::unique_ptr* result) { return errors::Unimplemented("Asset storage is read only."); diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h index 665304b5eef1f8a3633c8c522259e20d744b1808..a87ff42ae217c429ecf5d2458b88b3431551ad97 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.h +++ b/tensorflow/contrib/android/asset_manager_filesystem.h @@ -66,6 +66,9 @@ class AssetManagerFileSystem : public FileSystem { Status DeleteDir(const string& d) override; Status RenameFile(const string& s, const string& t) override; + Status GetMatchingPaths(const string& pattern, + std::vector* results) override; + private: string RemoveAssetPrefix(const string& name); diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index a115d1610e2334a6626f29674f3dd195e3a3c648..ecf1a103d2981f409a4598d762fb26100217f779 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -75,7 +75,6 @@ target_link_libraries(tensorflow_inference include_directories( ${PREBUILT_DIR}/proto ${PREBUILT_DIR}/protobuf/include - ${PREBUILT_DIR}/nsync/public ${TENSORFLOW_ROOT_DIR}/tensorflow/contrib/makefile/downloads/eigen ${TENSORFLOW_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/tensorflow/contrib/py2tf/BUILD b/tensorflow/contrib/autograph/BUILD similarity index 75% rename from tensorflow/contrib/py2tf/BUILD rename to tensorflow/contrib/autograph/BUILD index d91220f6ddb859ff52d4e5853948cb667981009b..30dd846893c30b9205972bd5216cc1871ab03d76 100644 --- a/tensorflow/contrib/py2tf/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -15,16 +15,16 @@ filegroup( ) py_library( - name = "py2tf", + name = "autograph", srcs = [ "__init__.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/impl", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], diff --git a/tensorflow/contrib/py2tf/README.md b/tensorflow/contrib/autograph/README.md similarity index 87% rename from tensorflow/contrib/py2tf/README.md rename to tensorflow/contrib/autograph/README.md index cd50675ad57316b9c749c137e6acd30b91c10073..7e84f237dc9a83098f142a54c48cf5b6ba35aaaa 100644 --- a/tensorflow/contrib/py2tf/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,4 +1,4 @@ -# Py2TF +# Autograph A compiler for generating TensorFlow numeric and control flow ops from Python code. diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/autograph/__init__.py similarity index 57% rename from tensorflow/contrib/py2tf/__init__.py rename to tensorflow/contrib/autograph/__init__.py index 379fa7fd5c2a22b5b16a21cca8c2ea8afdcaeefa..3386c4eca4b93e850f6fe3c6239d29c61d787ece 100644 --- a/tensorflow/contrib/py2tf/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Py2TF compiles Python code into equivalent TensorFlow code. +"""Autograph compiles Python code into equivalent TensorFlow code. Equivalent here means that they have the same effect when executed. """ @@ -21,16 +21,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.impl.api import convert -from tensorflow.contrib.py2tf.impl.api import graph_ready -from tensorflow.contrib.py2tf.impl.api import to_code -from tensorflow.contrib.py2tf.impl.api import to_graph -from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError +# TODO(mdan): Bring only the relevant symbols to the top level. +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl.api import convert +from tensorflow.contrib.autograph.impl.api import converted_call +from tensorflow.contrib.autograph.impl.api import do_not_convert +from tensorflow.contrib.autograph.impl.api import RunMode +from tensorflow.contrib.autograph.impl.api import to_code +from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError' + 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', + 'to_code', 'to_graph', 'AutographParseError' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD similarity index 70% rename from tensorflow/contrib/py2tf/converters/BUILD rename to tensorflow/contrib/autograph/converters/BUILD index 93c751b28dae3aa480aed839029bd37a2f47056b..8f9bffa55e44e4942bb3845945b3d440c7957cc9 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -24,10 +24,13 @@ py_library( "continue_statements.py", "control_flow.py", "decorators.py", - "for_loops.py", + "ifexp.py", "list_comprehension.py", + "lists.py", "logical_expressions.py", + "name_scopes.py", "side_effect_guards.py", + "single_return.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -45,8 +48,10 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ ":converters", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -56,9 +61,9 @@ py_test( name = "asserts_test", srcs = ["asserts_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -69,7 +74,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -78,20 +82,22 @@ py_test( name = "builtin_functions_test", srcs = ["builtin_functions_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( name = "call_trees_test", + size = "large", srcs = ["call_trees_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], ) @@ -102,7 +108,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -113,7 +118,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -124,17 +128,16 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( - name = "for_loops_test", - srcs = ["for_loops_test.py"], + name = "name_scopes_test", + srcs = ["name_scopes_test.py"], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -145,7 +148,16 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "lists_test", + srcs = ["lists_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", "//tensorflow/python:client_testlib", ], ) @@ -156,7 +168,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -165,9 +176,35 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", + tags = [ + # TODO(mdan): Fix. + "flaky", + "notap", + ], + deps = [ + ":test_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "single_return_test", + srcs = ["single_return_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "ifexp_test", + srcs = ["ifexp_test.py"], + srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/converters/__init__.py b/tensorflow/contrib/autograph/converters/__init__.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/__init__.py rename to tensorflow/contrib/autograph/converters/__init__.py index ca10896ee5c6c23d9b20ff23add9945de68e5bf9..e4e8eda42f655e204310eaa9defdd5c90bf06e15 100644 --- a/tensorflow/contrib/py2tf/converters/__init__.py +++ b/tensorflow/contrib/autograph/converters/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Code converters used by Py2TF.""" +"""Code converters used by Autograph.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/py2tf/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py similarity index 86% rename from tensorflow/contrib/py2tf/converters/asserts.py rename to tensorflow/contrib/autograph/converters/asserts.py index 5b9b8e772bed82df2429fd6cb94dbf7b565e22b3..2d9e2c58e3afcef5c18f477a7a29e518e98e672e 100644 --- a/tensorflow/contrib/py2tf/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,15 +20,13 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class AssertsTransformer(transformer.Base): """Transforms Print nodes to Call so they can be handled as functions.""" - # pylint:disable=invalid-name - def visit_Assert(self, node): self.generic_visit(node) @@ -44,9 +42,7 @@ class AssertsTransformer(transformer.Base): elif isinstance(node.msg, gast.Str): return templates.replace(template, test=node.test, msg=node.msg) else: - raise NotImplementedError('Can only convert string messages for now.') - - # pylint:enable=invalid-name + raise NotImplementedError('can only convert string messages for now.') def transform(node, context): diff --git a/tensorflow/contrib/py2tf/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py similarity index 90% rename from tensorflow/contrib/py2tf/converters/asserts_test.py rename to tensorflow/contrib/autograph/converters/asserts_test.py index 6611f2777a93a7e819c8becfa06a09b27f4e6aaf..cc913febe8d0f411588af69b87ec52ce58f4469c 100644 --- a/tensorflow/contrib/py2tf/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py new file mode 100644 index 0000000000000000000000000000000000000000..5dfb7a59d51859983c7e0c5549facc3a48b2d285 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes break statements by de-sugaring into a control boolean.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +class BreakCanonicalizationTransformer(transformer.Base): + """Canonicalizes break statements into additional conditionals.""" + + def __init__(self, context): + super(BreakCanonicalizationTransformer, self).__init__(context) + # This is a stack structure, to correctly process nested loops. + # Each item is a list [break_used, break_variable_name] + self.break_uses = [] + + def visit_Break(self, node): + self.break_uses[-1][0] = True + template = """ + var_name = True + continue + """ + return templates.replace(template, var_name=self.break_uses[-1][1]) + + def visit_While(self, node): + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + break_var = self.context.namer.new_symbol('break_requested', + scope.referenced) + + self.break_uses.append([False, break_var]) + node = self.generic_visit(node) + if self.break_uses[-1][0]: + template = """ + var_name = False + while original_test and not var_name: + original_body + else: + original_orelse + """ + node = templates.replace( + template, + var_name=break_var, + original_test=node.test, + original_body=node.body, + original_orelse=node.orelse) + self.break_uses.pop() + + return node + + def visit_For(self, node): + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + break_var = self.context.namer.new_symbol('break_requested', + scope.referenced) + + self.break_uses.append([False, break_var]) + node = self.generic_visit(node) + if self.break_uses[-1][0]: + template = """ + var_name = False + original_for + """ + node = templates.replace( + template, + var_name=break_var, + original_for=node) + extra_cond = templates.replace_as_expression( + 'not var_name', var_name=break_var) + new_for_node = node[1] + anno.setanno(new_for_node, 'extra_cond', extra_cond) + self.break_uses.pop() + + return node + + +def transform(node, context): + return BreakCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/break_statements_test.py rename to tensorflow/contrib/autograph/converters/break_statements_test.py index 095fcdff07d44ecc6b9bb7f8d3e2c7c43df72a02..dd4914a022f57b3bb4a19ec132f311f12269fa9e 100644 --- a/tensorflow/contrib/py2tf/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import break_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py similarity index 76% rename from tensorflow/contrib/py2tf/converters/builtin_functions.py rename to tensorflow/contrib/autograph/converters/builtin_functions.py index 2eb00f90575920ac948e799b0e97a9cfccb42fad..317711a866f731de1b497295a2752dee0eb544f5 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,8 +20,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class BuiltinFunctionTransformer(transformer.Base): @@ -34,25 +34,26 @@ class BuiltinFunctionTransformer(transformer.Base): def __init__(self, context): super(BuiltinFunctionTransformer, self).__init__(context) - # pylint:disable=invalid-name - - def _convert_len(self, node): + def _convert_builtin(self, node): template = """ - tf.shape(args)[0] + ag__.utils.dynamic_builtin(func, args) """ - return templates.replace(template, args=node.args)[0].value + return templates.replace(template, func=node.func, args=node.args)[0].value def _convert_print(self, node): template = """ - py2tf_utils.call_print(args) + ag__.utils.dynamic_print(args) """ return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. - if isinstance(node.func, gast.Name) and node.func.id == 'len': - return self._convert_len(node) + # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. + if (isinstance(node.func, gast.Name) and + node.func.id in ('len', 'range', 'xrange')): + return self._convert_builtin(node) + # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': return self._convert_print(node) return node @@ -69,8 +70,6 @@ class BuiltinFunctionTransformer(transformer.Base): function_call = templates.replace(template, fname='print', args=args)[0] return self.visit(function_call) - # pylint:enable=invalid-name - def transform(node, context): return BuiltinFunctionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py similarity index 64% rename from tensorflow/contrib/py2tf/converters/builtin_functions_test.py rename to tensorflow/contrib/autograph/converters/builtin_functions_test.py index b279ff77ef10b96586d3d68585adb0d5424afb90..30272409df322560b04ba75b3e1cb6f9ad5ff0af 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -22,12 +22,10 @@ import sys import six -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -47,7 +45,9 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): sess.run( result.test_fn(constant_op.constant([0, 0, 0])))) - def test_print_with_op(self): + self.assertEqual(3, result.test_fn([0, 0, 0])) + + def test_print(self): def test_fn(a): print(a) @@ -55,14 +55,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a') + result.test_fn(constant_op.constant('a')) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a\n') finally: @@ -70,41 +68,19 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): def test_print_with_op_multiple_values(self): - def test_fn(a, b): - print(a, b) - - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) - - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: - with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn('a', 1) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_print_with_py_func(self): - def test_fn(a, b, c): print(a, b, c) node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include logging_ops.Print here, to verify - # that py_func is used. - with self.compiled(node, script_ops.py_func) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a', 1, [2, 3]) + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3]) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') finally: diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py similarity index 69% rename from tensorflow/contrib/py2tf/converters/call_trees.py rename to tensorflow/contrib/autograph/converters/call_trees.py index 1050ba654c63bb52c1c5e71c981a6a0baa3fc987..685fd39d7cd7d0b15c9240032c9767392ec3642c 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -22,17 +22,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import types +from collections import namedtuple import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect +class FunctionInfo(namedtuple('FunctionInfo', ('dtype',))): + pass + + +# TODO(mdan): Move this to config.py. +KNOWN_NUMPY_FUNCTIONS = { + ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'), +} + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -72,9 +84,8 @@ class CallTreeTransformer(transformer.Base): self.uncompiled_modules = uncompiled_modules self.nocompile_decorators = nocompile_decorators - # pylint:disable=invalid-name - def _resolve_name(self, node): + """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): @@ -99,7 +110,19 @@ class CallTreeTransformer(transformer.Base): (owner_type, node.attr)) return None + def _function_is_compilable(self, target_entity): + """Determines whether an entity can be compiled at all.""" + # TODO(mdan): This is just a placeholder. Implement. + return not inspect_utils.isbuiltin(target_entity) + def _should_compile(self, node, fqn): + """Determines whether an entity should be compiled in the context.""" + # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. + module_name = fqn[0] + for mod in self.uncompiled_modules: + if module_name.startswith(mod[0] + '.'): + return False + for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False @@ -123,7 +146,7 @@ class CallTreeTransformer(transformer.Base): # Inspect the target function decorators. If any include a @convert # or @graph_ready annotation, then they must be called as they are. # TODO(mdan): This may be quite heavy. - # To parse and re-analize each function for every call site could be quite + # To parse and re-analyze each function for every call site could be quite # wasteful. Maybe we could cache the parsed AST? try: target_node, _ = parser.parse_entity(target_entity) @@ -141,33 +164,6 @@ class CallTreeTransformer(transformer.Base): return True - def _determine_function_owner(self, m): - # TODO(mdan): The parent type should be known at analysis. Use that instead. - if hasattr(m, 'im_class'): # Python 2 - return m.im_class - if hasattr(m, '__qualname__'): # Python 3 - # Object attributes: should be bound to "self". - if hasattr(m, '__self__'): - return type(m.__self__) - - # Class attributes: should have the owner name in their namespace. - qn = m.__qualname__.split('.') - if len(qn) < 2: - return None - owner_name, func_name = qn[-2:] - if func_name != m.__name__: - raise ValueError('Inconsistent names detected ' - '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % - (func_name, m.__name__, m)) - if owner_name == '': - return None - if owner_name not in self.context.namespace: - raise ValueError( - 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % - (owner_name, m, self.context.namespace)) - return self.context.namespace[owner_name] - return None - def _rename_compilable_function(self, node): assert anno.hasanno(node.func, 'live_val') assert anno.hasanno(node.func, 'fqn') @@ -182,7 +178,11 @@ class CallTreeTransformer(transformer.Base): target_fqn, live_entity=target_entity) do_rename = True else: - owner_type = self._determine_function_owner(target_entity) + if anno.hasanno(node.func, 'parent_type'): + owner_type = anno.getanno(node.func, 'parent_type') + else: + # Fallback - not reliable. + owner_type = inspect_utils.getmethodclass(target_entity) new_name, do_rename = self.context.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) @@ -196,15 +196,56 @@ class CallTreeTransformer(transformer.Base): return node def _wrap_to_py_func_no_return(self, node): - # TODO(mdan): Properly handle varargs, kwargs, etc. + # TODO(mdan): Properly handle varargs, etc. template = """ - py2tf_utils.wrap_py_func(func, None, (original_args,), True) + ag__.utils.wrap_py_func(func, None, (args,), kwargs, True) """ - return templates.replace(template, func=node.func, original_args=node.args) + return templates.replace( + template, + func=node.func, + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _wrap_to_py_func_single_return(self, node, dtype): + # TODO(mdan): Properly handle varargs, etc. + template = """ + ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False) + """ + return templates.replace_as_expression( + template, + func=node.func, + dtype=parser.parse_expression(dtype), + args=node.args, + kwargs=ast_util.keywords_to_dict(node.keywords)) + + def _insert_dynamic_conversion(self, node): + """Inlines a dynamic conversion for a dynamic function.""" + # TODO(mdan): Pass information on the statically compiled functions. + # Having access to the statically compiled functions can help avoid + # unnecessary compilation. + # For example, this would lead to function `a` being compiled twice: + # + # def a(): + # v = b + # b() + # def b(): + # a() + # + # This is really a problem with recursive calls, which currently can + # only be gated by a static condition, and should be rare. + # TODO(mdan): It probably makes sense to use dynamic conversion every time. + # Before we could convert all the time though, we'd need a reasonable + # caching mechanism. + template = """ + ag__.converted_call(func, True, False, {}, args) + """ + call_expr = templates.replace(template, func=node.func, args=node.args) + new_call = call_expr[0].value + # TODO(mdan): Improve the template mechanism to better support this. + new_call.keywords = node.keywords + return new_call - def _function_is_compilable(self, target_entity): - # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_entity, types.BuiltinFunctionType) + # pylint:disable=invalid-name def visit_Expr(self, node): if isinstance(node.value, gast.Call): @@ -239,15 +280,24 @@ class CallTreeTransformer(transformer.Base): self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') + if anno.hasanno(node.func, 'fqn'): + target_fqn = anno.getanno(node.func, 'fqn') + else: + target_fqn = None if self._function_is_compilable(target_entity): node = self._rename_compilable_function(node) + elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS: + # TODO(mdan): Should we replace these with equivalent TF ops instead? + node = self._wrap_to_py_func_single_return( + node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype) else: - raise NotImplementedError('py_func with return values') + raise NotImplementedError( + 'py_func with return values (unknown function)') else: if self.context.recursive: - raise NotImplementedError('Could not resolve target function.') + node = self._insert_dynamic_conversion(node) else: - # TODO(mdan): Double check. Is this reachable code? + # Unresolved functions are allowed in non-recursive mode. pass return node diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py similarity index 75% rename from tensorflow/contrib/py2tf/converters/call_trees_test.py rename to tensorflow/contrib/autograph/converters/call_trees_test.py index 777648dc0b31863227262fbf931aba680bb4ed98..303dd54a4ee49de27fad0c5cdc2d6274abfe0fa8 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import converter_test_base +import numpy as np + +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -30,7 +34,7 @@ class CallTreesTest(converter_test_base.TestCase): def test_basic(self): def test_fn_1(_): - raise ValueError('This should not be called in the compiled verison.') + raise ValueError('This should not be called in the compiled version.') def renamed_test_fn_1(a): return a + 1 @@ -47,6 +51,21 @@ class CallTreesTest(converter_test_base.TestCase): result.renamed_test_fn_1 = renamed_test_fn_1 self.assertEquals(3, result.test_fn_2(1)) + def test_dynamic_function(self): + + def test_fn_1(): + raise ValueError('This should be masked by the mock.') + + def test_fn_2(f): + return f() + 3 + + node = self.parse_and_analyze(test_fn_2, {}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + # 10 = 7 (from the mock) + 3 (from test_fn_2) + self.assertEquals(10, result.test_fn_2(test_fn_1)) + def test_simple_methods(self): class TestClass(object): @@ -59,6 +78,7 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, + namer=converter_test_base.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) node = call_trees.transform(node, self.ctx, (), ()) @@ -89,6 +109,20 @@ class CallTreesTest(converter_test_base.TestCase): sess.run(sess.graph.get_operations()[0]) self.assertEquals('bar', a.foo) + def test_py_func_wrap_known_function(self): + + def test_fn(): + return np.random.binomial(2, 0.5) + + node = self.parse_and_analyze(test_fn, {'np': np}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node, dtypes.int64) as result: + result.np = np + with self.test_session() as sess: + self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) + self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) + def test_uncompiled_modules(self): def test_fn(a): diff --git a/tensorflow/contrib/py2tf/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py similarity index 94% rename from tensorflow/contrib/py2tf/converters/continue_statements.py rename to tensorflow/contrib/autograph/converters/continue_statements.py index 4069a678b118b56b59d2e5491bb80cf52efd8143..4299a8a9d59715d032222c47794bbb4393f34ce6 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class ContinueCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py similarity index 95% rename from tensorflow/contrib/py2tf/converters/continue_statements_test.py rename to tensorflow/contrib/autograph/converters/continue_statements_test.py index a598dcd1aed29478b7e3fe27e3c1b20010247dd9..bcbb316d7459aa5a25bb0bd128cd6e359a393288 100644 --- a/tensorflow/contrib/py2tf/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.platform import test diff --git a/tensorflow/contrib/py2tf/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py similarity index 69% rename from tensorflow/contrib/py2tf/converters/control_flow.py rename to tensorflow/contrib/autograph/converters/control_flow.py index d53e3e4fd6d87004cbe55bd430346ad263e898ea..2e26cdb3d9387d358e0225555506f199e9945d0b 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,11 +20,12 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -49,11 +50,6 @@ class ControlFlowTransformer(transformer.Base): def __init__(self, context): super(ControlFlowTransformer, self).__init__(context) - # pylint:disable=invalid-name - - def visit_For(self, node): - assert False, 'for statement should have been canonicalized at this point' - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -82,7 +78,7 @@ class ControlFlowTransformer(transformer.Base): def _create_cond_expr(self, results, test, body_name, orelse_name): if results is not None: template = """ - results = py2tf_utils.run_cond(test, body_name, orelse_name) + results = ag__.utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, @@ -92,7 +88,7 @@ class ControlFlowTransformer(transformer.Base): orelse_name=orelse_name) else: template = """ - py2tf_utils.run_cond(test, body_name, orelse_name) + ag__.utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) @@ -170,7 +166,22 @@ class ControlFlowTransformer(transformer.Base): body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced + cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE) + cond_closure = set() + for s in cond_scope.referenced: + for root in s.support_set: + if root not in body_scope.created: + cond_closure.add(root) + state = list(body_closure) + if not state: + # TODO(mdan): Implement this properly. + # To complete this statement, we need to check whether any variable + # created inside the body scope is used before being modified outside the + # scope. This should be done during activity analysis, and in general + # should cover the case where variables may not be initialized. + raise ValueError('cannot convert while loop: no outputs') + state_ssf = [ self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state ] @@ -196,7 +207,8 @@ class ControlFlowTransformer(transformer.Base): def body_name(state_ssf): body return state_ssf, - state_ast_tuple = py2tf_utils.run_while(test_name, body_name, [state]) + state_ast_tuple = ag__.while_loop( + test_name, body_name, (state,), (extra_deps,)) """ node = templates.replace( template, @@ -208,11 +220,67 @@ class ControlFlowTransformer(transformer.Base): test=test, body_name=self.context.namer.new_symbol('loop_body', body_scope.referenced), - body=node_body) + body=node_body, + extra_deps=tuple(s.ast() for s in cond_closure), + ) return node - # pylint:enable=invalid-name + def visit_For(self, node): + self.generic_visit(node) + + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + body_closure = body_scope.modified - body_scope.created + all_referenced = body_scope.referenced + + state = list(body_closure) + + state_ssf = [ + self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + ] + ssf_map = { + name: ssf + for name, ssf in zip(state, state_ssf) + if str(name) != ssf + } + + if len(state) == 1: + state = state[0] + state_ssf = state_ssf[0] + state_ast_tuple = state + else: + state_ast_tuple = gast.Tuple([n.ast() for n in state], None) + + node_body = ast_util.rename_symbols(node.body, ssf_map) + if anno.hasanno(node, 'extra_cond'): + extra_cond = anno.getanno(node, 'extra_cond') + extra_cond = ast_util.rename_symbols(extra_cond, ssf_map) + else: + extra_cond = parser.parse_expression('True') + + template = """ + def extra_cond_name(state_ssf): + return extra_cond_expr + def body_name(iterate, state_ssf): + body + return state_ssf, + state_ast_tuple = ag__.for_loop( + iterated, extra_cond_name, body_name, (state,)) + """ + node = templates.replace( + template, + state=state, + state_ssf=state_ssf, + state_ast_tuple=state_ast_tuple, + iterated=node.iter, + iterate=node.target, + extra_cond_name=self.context.namer.new_symbol('extra_cond', + all_referenced), + extra_cond_expr=extra_cond, + body_name=self.context.namer.new_symbol('loop_body', all_referenced), + body=node_body) + + return node def transform(node, context): diff --git a/tensorflow/contrib/py2tf/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py similarity index 58% rename from tensorflow/contrib/py2tf/converters/control_flow_test.py rename to tensorflow/contrib/autograph/converters/control_flow_test.py index b785b284a7fb7a0257551326c88b44a341b295ba..c5610b16b4e5de374f404307d3583660707d5e0b 100644 --- a/tensorflow/contrib/py2tf/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -18,9 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import control_flow -from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test @@ -94,6 +95,77 @@ class ControlFlowTest(converter_test_base.TestCase): with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + def test_simple_for(self): + + def test_fn(l): + s1 = 0 + s2 = 0 + for e in l: + s1 += e + s2 += e * e + return s1, s2 + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + l = [1, 2, 3] + self.assertEqual( + test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) + l = [] + self.assertEqual( + test_fn(l), + sess.run( + result.test_fn( + constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) + + def test_for_single_var(self): + + def test_fn(l): + s = 0 + for e in l: + s += e + return s + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + with self.test_session() as sess: + l = [1, 2, 3] + self.assertEqual( + test_fn(l), sess.run(result.test_fn(constant_op.constant(l)))) + l = [] + self.assertEqual( + test_fn(l), + sess.run( + result.test_fn( + constant_op.constant(l, shape=(0,), dtype=dtypes.int32)))) + + def test_for_with_iterated_expression(self): + + eval_count = [0] + + def count_evals(x): + eval_count[0] += 1 + return x + + def test_fn(n): + s = 0 + for e in count_evals(range(n)): + s += e + return s + + node = self.parse_and_analyze(test_fn, {'count_evals': count_evals}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node) as result: + result.count_evals = count_evals + self.assertEqual(test_fn(5), result.test_fn(5)) + # count_evals ran twice, once for test_fn and another for result.test_fn + self.assertEqual(eval_count[0], 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py similarity index 58% rename from tensorflow/contrib/py2tf/converters/converter_test_base.py rename to tensorflow/contrib/autograph/converters/converter_test_base.py index 67747183dd323a799a04943ce4c7fe8c4093d002..23b61cf78155b376f5bd1760f9a45669c2589679 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/converters/converter_test_base.py @@ -21,14 +21,16 @@ from __future__ import print_function import contextlib import imp -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.pyct import compiler -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test @@ -52,26 +54,51 @@ class FakeNamer(object): return ('renamed_%s' % '_'.join(original_fqn)), True +class FakeNoRenameNamer(FakeNamer): + + def compiled_function_name(self, original_fqn, **_): + return str(original_fqn), False + + class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager def compiled(self, node, *symbols): - source = '' + source = None + + self.dynamic_calls = [] + def converted_call(*args): + """Mock version of api.converted_call.""" + self.dynamic_calls.append(args) + return 7 + try: result, source = compiler.ast_to_object(node) - result.tf = self.make_fake_tf(*symbols) - result.py2tf_utils = utils + result.tf = self.make_fake_mod('fake_tf', *symbols) + fake_ag = self.make_fake_mod('fake_ag', converted_call) + fake_ag.__dict__.update(operators.__dict__) + fake_ag.__dict__['utils'] = utils + result.__dict__['ag__'] = fake_ag yield result except Exception: # pylint:disable=broad-except - print('Offending compiled code:\n%s' % source) + if source is None: + print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False)) + else: + print('Offending compiled code:\n%s' % source) raise - def make_fake_tf(self, *symbols): - fake_tf = imp.new_module('fake_tf') + def make_fake_mod(self, name, *symbols): + fake_mod = imp.new_module(name) for s in symbols: - setattr(fake_tf, s.__name__, s) - return fake_tf + if hasattr(s, '__name__'): + setattr(fake_mod, s.__name__, s) + elif hasattr(s, 'name'): + # This is a bit of a hack, but works for things like tf.int32 + setattr(fake_mod, s.name, s) + else: + raise ValueError('can not attach %s - what should be its name?' % s) + return fake_mod def attach_namespace(self, module, **ns): for k, v in ns.items(): @@ -83,6 +110,7 @@ class TestCase(test.TestCase): namer=None, arg_types=None, include_type_analysis=True, + owner_type=None, recursive=True): node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( @@ -92,7 +120,9 @@ class TestCase(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, - recursive=recursive) + owner_type=owner_type, + recursive=recursive, + type_annotation_func=utils.set_element_type) node = qual_names.resolve(node) node = activity.resolve(node, ctx) node = live_values.resolve(node, ctx, {}) diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py similarity index 57% rename from tensorflow/contrib/py2tf/converters/decorators.py rename to tensorflow/contrib/autograph/converters/decorators.py index 3f620c1cd2d9b75f82410754a7e812e13eabe3ae..92445f31746cf94856ea43893f99a2ba60355fb5 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,8 +24,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import pretty_printer class DecoratorsTransformer(gast.NodeTransformer): @@ -33,6 +33,7 @@ class DecoratorsTransformer(gast.NodeTransformer): def __init__(self, remove_decorators): self.remove_decorators = remove_decorators + self.additional_dependencies = set() # pylint:disable=invalid-name @@ -44,13 +45,38 @@ class DecoratorsTransformer(gast.NodeTransformer): dec_func = dec.func else: dec_func = dec + + # Special cases. + # TODO(mdan): Is there any way we can treat these more generically? + # We may want to forego using decorators altogether if we can't + # properly support them. + if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',): + # Assumption: decorators are only visible in the AST when converting + # a function inline (via another decorator). + # In that case, the converted function is no longer part of the + # original object that it was declared into. + # This is currently verified by tests. + continue + if not anno.hasanno(dec_func, 'live_val'): raise ValueError( 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) + dec_value = anno.getanno(dec_func, 'live_val') if dec_value not in self.remove_decorators: - kept_decorators.append(dec) - node.decorator_list = kept_decorators + kept_decorators.append((dec, dec_value)) + + for _, dec_value in kept_decorators: + if dec_value.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_value, + dec_value.__module__)) + else: + self.additional_dependencies.add(dec_value) + + node.decorator_list = [dec for dec, _ in kept_decorators] return node # pylint:enable=invalid-name @@ -59,4 +85,4 @@ class DecoratorsTransformer(gast.NodeTransformer): def transform(node, remove_decorators): transformer = DecoratorsTransformer(remove_decorators) node = transformer.visit(node) - return node + return node, transformer.additional_dependencies diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9c01f689127dbedad7669c65b03e7da071b2d64d --- /dev/null +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for decorators module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.python.platform import test + + +# The Python parser only briefly captures decorators into the AST. +# The interpreter desugars them on load, and the decorated function loses any +# trace of the decorator (which is normally what you would expect, since +# they are meant to be transparent). +# However, decorators are still visible when you analyze the function +# from inside a decorator, before it was applied - as is the case +# with our conversion decorators. + + +def simple_decorator(f): + return lambda a: f(a) + 1 + + +def self_removing_decorator(removing_wrapper): + def decorator(f): + @wraps(f) + def wrapper(*args): + # This removing wrapper is defined in the test below. This setup is so + # intricate just to simulate how we use the transformer in practice. + transformed_f = removing_wrapper(f, (self_removing_decorator,)) + return transformed_f(*args) + 1 + return wrapper + return decorator + + +class DecoratorsTest(converter_test_base.TestCase): + + def _remover_wrapper(self, f, remove_decorators): + namespace = { + 'self_removing_decorator': self_removing_decorator, + 'simple_decorator': simple_decorator + } + node = self.parse_and_analyze(f, namespace) + node, _ = decorators.transform(node, remove_decorators=remove_decorators) + result, _ = compiler.ast_to_object(node) + return getattr(result, f.__name__) + + def test_noop(self): + + def test_fn(a): + return a + + node = self.parse_and_analyze(test_fn, {}) + node, deps = decorators.transform(node, remove_decorators=()) + result, _ = compiler.ast_to_object(node) + + self.assertFalse(deps) + self.assertEqual(1, result.test_fn(1)) + + def test_function(self): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(a): + return a + + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, test_fn(1)) + + def test_method(self): + + class TestClass(object): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(self, a): + return a + + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, TestClass().test_fn(1)) + + def test_multiple_decorators(self): + + class TestClass(object): + + # Note that reversing the order of this two doesn't work. + @classmethod + @self_removing_decorator(self._remover_wrapper) + def test_fn(cls, a): + return a + + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, TestClass.test_fn(1)) + + def test_nested_decorators(self): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(a): + @simple_decorator + def inner_fn(b): + return b + 11 + return inner_fn(a) + + with self.assertRaises(ValueError): + test_fn(1) + + # TODO(mdan): Uncomment this test once converter_test_base is updated. + # (can't do it now because it has unrelated pending changes) + # def test_nested_decorators(self): + # + # @self_removing_decorator(self._remover_wrapper) + # def test_fn(a): + # @imported_decorator + # def inner_fn(b): + # return b + 11 + # return inner_fn(a) + # + # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + # self.assertEqual(14, test_fn(1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/autograph/converters/ifexp.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/printing.py rename to tensorflow/contrib/autograph/converters/ifexp.py index 95a62bd80b5f4854e6a062df18d882f7bd495555..616d222762e09feeba1809f119d915dfbe522283 100644 --- a/tensorflow/contrib/py2tf/utils/printing.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -12,36 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow printing support utilities.""" +"""Canonicalizes the ternary conditional operator.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import py_func -from tensorflow.python.ops import logging_ops +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. - del value - return False +class IfExp(transformer.Base): + """Canonicalizes all IfExp nodes into plain conditionals.""" + def visit_IfExp(self, node): + template = """ + ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,)) + """ + desugared_ifexp = templates.replace_as_expression( + template, test=node.test, body=node.body, orelse=node.orelse) + return desugared_ifexp -def call_print(*values): - """Compiled counterpart of the print builtin. - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. +def transform(node, context): + """Desugar IfExp nodes into plain conditionals. Args: - *values: values to print + node: an AST node to transform + context: a context object + Returns: - A dummy value indicating the print completed. If tf. + new_node: an AST with no IfExp nodes, only conditionals. """ - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - return py_func.wrap_py_func(print, None, values, use_dummy_return=True) + node = IfExp(context).visit(node) + return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e --- /dev/null +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -0,0 +1,106 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ifexp module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.python.platform import test + + +class IfExpTest(converter_test_base.TestCase): + + def compiled_fn(self, test_fn, *args): + node = self.parse_and_analyze(test_fn, {}) + node = ifexp.transform(node, self.ctx) + module = self.compiled(node, *args) + return module + + def test_simple(self): + + def test_fn(x): + return 1 if x else 0 + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [0, 1]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_fn(self): + + def f(x): + return 3 * x + + def test_fn(x): + y = f(x * x if x > 0 else x) + return y + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + result.f = f + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_exp(self): + + def test_fn(x): + return x * x if x > 0 else x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_nested(self): + + def test_fn(x): + return x * x if x > 0 else x if x else 1 + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 0, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_in_cond(self): + + def test_fn(x): + if x > 0: + return x * x if x < 5 else x * x * x + return -x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_assign_in_cond(self): + + def test_fn(x): + if x > 0: + x = -x if x < 5 else x + return x + + with self.compiled_fn(test_fn) as result: + result.autograph_util = utils + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension.py rename to tensorflow/contrib/autograph/converters/list_comprehension.py index e8744831100e4852919b5cd1253b74acea4d790d..d7f292015164e047d054c5d1fb0b391e960bb73d 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,9 +31,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer class ListCompCanonicalizationTransformer(transformer.Base): diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py similarity index 93% rename from tensorflow/contrib/py2tf/converters/list_comprehension_test.py rename to tensorflow/contrib/autograph/converters/list_comprehension_test.py index 025fac11e41e6771fbb9b80ff3da70dc3ceec73e..4758671f5ec83c26cfa54be0ef68f5f564094f6c 100644 --- a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import list_comprehension +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import list_comprehension from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py new file mode 100644 index 0000000000000000000000000000000000000000..6dda554acc6331830e4a086f2d51aefcdf2ecd91 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -0,0 +1,106 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for list operations. + +This includes converting Python lists to TensorArray/TensorList. +""" + +# TODO(mdan): Elaborate the logic here. +# TODO(mdan): Does it even make sense to attempt to try to use TAs? +# The current rule (always convert to TensorArray) is naive and insufficient. +# In general, a better mechanism could look like: +# * convert to TensorList by default +# * leave as Python list if the user explicitly forbids it +# * convert to TensorArray only when complete write once behavior can be +# guaranteed (e.g. list comprehensions) + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.python.framework import dtypes + + +class ListTransformer(transformer.Base): + """Converts lists and related operations to their TF counterpart.""" + + def _empty_list(self, node): + if not anno.hasanno(node, 'element_type'): + raise NotImplementedError( + 'type inference for empty lists is not yet supported; ' + 'use set_element_type(, ) to continue') + dtype = anno.getanno(node, 'element_type') + if not isinstance(dtype, dtypes.DType): + # TODO(mdan): Allow non-TF dtypes? + # That would be consistent with the dynamic dispatch pattern, but + # we must make sure that doesn't become confusing. + raise NotImplementedError('element type "%s" not yet supported' % dtype) + + dtype_name = dtype.name + # TODO(mdan): Does it ever make sense not to use tensor lists? + template = """ + tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + """ + return templates.replace_as_expression(template, dtype_name=dtype_name) + + def _pre_populated_list(self, node): + raise NotImplementedError('pre-populated lists') + + def visit_Expr(self, node): + node = self.generic_visit(node) + if isinstance(node.value, gast.Call): + call_node = node.value + + if not anno.hasanno(call_node.func, anno.Basic.QN): + return node + qn = anno.getanno(call_node.func, anno.Basic.QN) + + if qn.qn[-1] == 'append' and (len(call_node.args) == 1): + template = """ + target = ag__.utils.dynamic_list_append(target, element) + """ + node = templates.replace( + template, + target=qn.parent.ast(), + element=call_node.args[0]) + return node + + def visit_Assign(self, node): + node = self.generic_visit(node) + + # Only convert lists when they are assigned to a variable, e.g.: + # l = [] + # TODO(mdan): This rule should be improved. + if len(node.targets) != 1: + return node + if not isinstance(node.value, gast.List): + return node + if not isinstance(node.value.ctx, gast.Load): + return node + + if node.value.elts: + node.value = self._pre_populated_list(node.value) + else: + node.value = self._empty_list(node.value) + return node + + +def transform(node, context): + return ListTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py new file mode 100644 index 0000000000000000000000000000000000000000..749ba14347314f975c5a6e1111133336e2f5c5e6 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lists module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import lists +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class ListTest(converter_test_base.TestCase): + + def test_empty_annotated_list(self): + + def test_fn(): + l = [] + utils.set_element_type(l, dtypes.int32) + l.append(1) + return l + + node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = lists.transform(node, self.ctx) + + with self.compiled(node, tensor_array_ops.TensorArray, + dtypes.int32) as result: + # TODO(mdan): Attach these additional modules automatically. + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + self.assertEqual(test_fn(), sess.run(result.test_fn().stack())) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..3a795a315a3c2aa08ac1577a204102755b6e849c --- /dev/null +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for logical expressions. + +e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +# TODO(mdan): Properly extrack boolean ops according to lazy eval rules. +# Note that this isn't completely safe either, because tensors may have control +# dependencies. +# Note that for loops that should be done after the loop was converted to +# tf.while_loop so that the expanded conditionals are properly scoped. + +# Used to signal that an operand is safe for non-lazy evaluation. +SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' + + +class LogicalExpressionTransformer(transformer.Base): + """Converts logical expressions to corresponding TF calls.""" + + def __init__(self, context): + super(LogicalExpressionTransformer, self).__init__(context) + # TODO(mdan): Look into replacing with bitwise operators instead. + # TODO(mdan): Skip replacing if the function is trivial. + self.op_mapping = { + gast.And: 'tf.logical_and', + gast.Eq: 'tf.equal', + gast.Gt: 'tf.greater', + gast.GtE: 'tf.greater_equal', + gast.Lt: 'tf.less', + gast.LtE: 'tf.less_equal', + gast.Not: 'tf.logical_not', + gast.NotEq: 'tf.not_equal', + gast.Or: 'tf.logical_or', + gast.USub: 'tf.negative', + gast.Is: 'autograph_utils.dynamic_is', + gast.IsNot: 'autograph_utils.dynamic_is_not' + } + + def _expect_simple_symbol(self, operand): + if isinstance(operand, gast.Name): + return + if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND): + return + raise NotImplementedError( + 'only simple local variables are supported in logical and compound ' + 'comparison expressions; for example, we support "a or b" but not ' + '"a.x or b"; for a workaround, assign the expression to a local ' + 'variable and use that instead, for example "tmp = a.x", "tmp or b"') + + def _matching_func(self, operator): + op_type = type(operator) + mapped_op = self.op_mapping.get(op_type) + if not mapped_op: + raise NotImplementedError('operator %s is not yet supported' % op_type) + return mapped_op + + def _as_function(self, func_name, args): + template = """ + func_name(args) + """ + replacement = templates.replace_as_expression( + template, func_name=parser.parse_expression(func_name), args=args) + anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True) + return replacement + + def visit_Compare(self, node): + node = self.generic_visit(node) + ops_and_comps = list(zip(node.ops, node.comparators)) + left = node.left + op_tree = None + + # Repeated comparisons are converted to conjunctions: + # a < b < c -> a < b and b < c + while ops_and_comps: + op, right = ops_and_comps.pop(0) + binary_comparison = self._as_function( + self._matching_func(op), (left, right)) + if isinstance(left, gast.Name) and isinstance(right, gast.Name): + anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True) + if op_tree: + self._expect_simple_symbol(right) + op_tree = self._as_function('tf.logical_and', + (binary_comparison, op_tree)) + else: + op_tree = binary_comparison + left = right + assert op_tree is not None + return op_tree + + def visit_UnaryOp(self, node): + node = self.generic_visit(node) + return self._as_function(self._matching_func(node.op), node.operand) + + def visit_BoolOp(self, node): + node = self.generic_visit(node) + node_values = node.values + right = node.values.pop() + self._expect_simple_symbol(right) + while node_values: + left = node_values.pop() + self._expect_simple_symbol(left) + right = self._as_function(self._matching_func(node.op), (left, right)) + return right + + +def transform(node, context): + return LogicalExpressionTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py similarity index 86% rename from tensorflow/contrib/py2tf/converters/logical_expressions_test.py rename to tensorflow/contrib/autograph/converters/logical_expressions_test.py index a28326c517d468230f35e45f0fbfe5257d769895..2814060c4d831e4dddacb3dcbcbe1db42160db20 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import logical_expressions +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -32,7 +32,7 @@ class GradientsFunctionTest(converter_test_base.TestCase): return a == b node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, self.ctx) with self.compiled(node, math_ops.equal) as result: with self.test_session() as sess: @@ -45,7 +45,7 @@ class GradientsFunctionTest(converter_test_base.TestCase): return (a or b) and (a or b or c) node = self.parse_and_analyze(test_fn, {}) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, self.ctx) with self.compiled(node, math_ops.logical_or, math_ops.logical_and) as result: diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py new file mode 100644 index 0000000000000000000000000000000000000000..dfee529abaa8c14d9b408819b32c5199500a2c2f --- /dev/null +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps a function body with a `name_scope` of the function name.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class FunctionNameScopeTransformer(transformer.Base): + """Wrap a function body with a `name_scope` of the function name.""" + + def _name_for_current_scope(self): + innermost = self.enclosing_entities[-1] + if len(self.enclosing_entities) > 1: + parent = self.enclosing_entities[-2] + if isinstance(parent, gast.ClassDef): + # Methods also take the name of their class. + name = '%s/%s' % (parent.name, innermost.name) + else: + name = innermost.name + else: + name = innermost.name + + # Sanitize the name. + # See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope + # TensorFlow doesn't like leading underscores at the top level. + while name[0] == '_': + name = name[1:] + return name + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + + unscoped_body = [] + scoped_body = node.body + if scoped_body: + first = scoped_body[0] + if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str): + # Skip any docstring. + unscoped_body = scoped_body[:1] + scoped_body = scoped_body[1:] + + template = """ + with tf.name_scope(scope_name): + body + """ + scoped_body = templates.replace( + template, + scope_name=gast.Str(self._name_for_current_scope()), + body=scoped_body) + node.body = unscoped_body + scoped_body + return node + + +def transform(node, context): + return FunctionNameScopeTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..17692cbd880dbc1db4bb40ad7345e27907499f9d --- /dev/null +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for for_canonicalization module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class FunctionNameScopeTransformer(converter_test_base.TestCase): + + def test_basic(self): + + def test_fn(l): + """This should stay here.""" + a = 5 + l += a + return l + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant(1)) + self.assertIn('test_fn/', result_op.op.name) + + self.assertEqual('This should stay here.', result.test_fn.__doc__) + + def test_long_docstring(self): + + def test_fn(l): + """Multi-line docstring. + + Args: + l: A thing. + Returns: + l + """ + return l + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + self.assertIn('Multi-line', result.test_fn.__doc__) + self.assertIn('Returns:', result.test_fn.__doc__) + + def test_nested_functions(self): + + def test_fn(l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('test_fn/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('test_fn/inner_fn/', second_result_input_name) + + def test_method(self): + + class TestClass(object): + + def test_fn(self, l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + # Note that 'TestClass' was needed in the namespace here. + node = self.parse_and_analyze( + TestClass, {'TestClass': TestClass}, owner_type=TestClass) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.TestClass().test_fn(constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('TestClass/test_fn/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('TestClass/test_fn/inner_fn/', second_result_input_name) + + def test_operator(self): + + class TestClass(object): + + def __call__(self, l): + + def inner_fn(i): + return i ** 2 + + l += 4 + return inner_fn(l) + + # Note that 'TestClass' was needed in the namespace here. + node = self.parse_and_analyze( + TestClass.__call__, {'TestClass': TestClass}, owner_type=TestClass) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.__call__(TestClass(), constant_op.constant(1)) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('call__/', first_result_input_name) + self.assertNotIn('inner_fn', first_result_input_name) + self.assertIn('call__/inner_fn/', second_result_input_name) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py similarity index 92% rename from tensorflow/contrib/py2tf/converters/side_effect_guards.py rename to tensorflow/contrib/autograph/converters/side_effect_guards.py index 30976b3ec6db5a6607023ac804d9d54cfb296190..3bcb2d3c42c6e0663c8f78523199a364b6ac231f 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,12 +36,12 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct import templates -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class SymbolNamer(object): @@ -160,8 +160,8 @@ class SideEffectGuardTransformer(transformer.Base): [alias_map.get(s, s).ast() for s in guarded_args], None) template = """ - with py2tf_utils.control_dependency_on_returns(call): - aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args) + with ag__.utils.control_dependency_on_returns(call): + aliased_guarded_args = ag__.utils.alias_tensors(guarded_args) """ control_deps_guard = templates.replace( template, @@ -172,7 +172,7 @@ class SideEffectGuardTransformer(transformer.Base): alias_map = {} template = """ - with py2tf_utils.control_dependency_on_returns(call): + with ag__.utils.control_dependency_on_returns(call): pass """ control_deps_guard = templates.replace(template, call=node.value)[-1] diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py similarity index 97% rename from tensorflow/contrib/py2tf/converters/side_effect_guards_test.py rename to tensorflow/contrib/autograph/converters/side_effect_guards_test.py index 463db2e770213ba9636d2537b095a77dece5d8f6..ce0ce33243a1352107eb8121050ee76474869809 100644 --- a/tensorflow/contrib/py2tf/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -0,0 +1,317 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizes functions with multiple returns to use just one.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# TODO(mdan): Move this logic into transformer_base. +class BodyVisitor(transformer.Base): + """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" + + def __init__(self, context, depth_first=False): + self.depth_first = depth_first + self.changes_made = False + super(BodyVisitor, self).__init__(context) + + def visit_nodelist(self, nodelist): + for node in nodelist: + if isinstance(node, list): + node = self.visit_nodelist(node) + else: + node = self.generic_visit(node) + return nodelist + + def visit_If(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_For(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_While(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_Try(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + node.orelse = self.visit_nodelist(node.orelse) + node.finalbody = self.visit_nodelist(node.finalbody) + for i in range(len(node.handlers)): + node.handlers[i].body = self.visit_nodelist(node.handlers[i].body) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_With(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + if not self.depth_first: + node = self.generic_visit(node) + return node + + def visit_FunctionDef(self, node): + if self.depth_first: + node = self.generic_visit(node) + node.body = self.visit_nodelist(node.body) + self.generic_visit(node) + if not self.depth_first: + node = self.generic_visit(node) + return node + + +class FoldElse(BodyVisitor): + + def visit_nodelist(self, nodelist): + for i in range(len(nodelist)): + node = nodelist[i] + if isinstance(node, gast.If): + true_branch_returns = isinstance(node.body[-1], gast.Return) + false_branch_returns = len(node.orelse) and isinstance( + node.orelse[-1], gast.Return) + # If the last node in the if body is a return, + # then every line after this if statement effectively + # belongs in the else. + if true_branch_returns and not false_branch_returns: + for j in range(i + 1, len(nodelist)): + nodelist[i].orelse.append(ast_util.copy_clean(nodelist[j])) + if nodelist[i + 1:]: + self.changes_made = True + return nodelist[:i + 1] + elif not true_branch_returns and false_branch_returns: + for j in range(i + 1, len(nodelist)): + nodelist[i].body.append(ast_util.copy_clean(nodelist[j])) + if nodelist[i + 1:]: + self.changes_made = True + return nodelist[:i + 1] + elif true_branch_returns and false_branch_returns: + if nodelist[i + 1:]: + raise ValueError( + 'Unreachable code after conditional where both branches return.' + ) + return nodelist + elif isinstance(node, gast.Return) and nodelist[i + 1:]: + raise ValueError( + 'Cannot have statements after a return in the same basic block') + return nodelist + + +def contains_return(node): + for n in gast.walk(node): + if isinstance(n, gast.Return): + return True + return False + + +class LiftReturn(transformer.Base): + """Move return statements out of If and With blocks.""" + + def __init__(self, context): + self.changes_made = False + self.common_return_name = None + super(LiftReturn, self).__init__(context) + + def visit_If(self, node): + # Depth-first traversal of if statements + node = self.generic_visit(node) + + # We check if both branches return, and if so, lift the return out of the + # conditional. We don't enforce that the true and false branches either + # both return or both do not, because FoldElse might move a return + # into a branch after this transform completes. FoldElse and LiftReturn + # are alternately run until the code reaches a fixed point. + true_branch_returns = isinstance(node.body[-1], gast.Return) + false_branch_returns = len(node.orelse) and isinstance( + node.orelse[-1], gast.Return) + if true_branch_returns and false_branch_returns: + node.body[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.body[-1].value)[0] + node.orelse[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.orelse[-1].value)[0] + return_node = templates.replace('return a', a=self.common_return_name)[0] + self.changes_made = True + return [node, return_node] + else: + return node + + def visit_With(self, node): + # Depth-first traversal of syntax + node = self.generic_visit(node) + + # If the with statement returns, lift the return + if isinstance(node.body[-1], gast.Return): + node.body[-1] = templates.replace( + 'a = b', a=self.common_return_name, b=node.body[-1].value)[0] + return_node = templates.replace('return a', a=self.common_return_name)[0] + node = self.generic_visit(node) + self.changes_made = True + return [node, return_node] + else: + return node + + def visit_FunctionDef(self, node): + # Ensure we're doing depth-first traversal + last_return_name = self.common_return_name + body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + referenced_names = body_scope.referenced + self.common_return_name = self.context.namer.new_symbol( + 'return_', referenced_names) + node = self.generic_visit(node) + self.common_return_name = last_return_name + return node + + +class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor): + """Throws an error if code returns inside loops or try/except.""" + + # First, throw an error if we detect a return statement in a loop. + # TODO(alexbw): we need to learn to handle returns inside a loop, + # but don't currently have the TF constructs to do so (need something + # that looks vaguely like a goto). + + def __init__(self): + self.cant_return = False + super(DetectReturnInUnsupportedControlFlow, self).__init__() + + def visit_While(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_For(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Try(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Return(self, node): + if self.cant_return: + raise ValueError( + '`return` statements are not supported in loops. ' + 'Try assigning to a variable in the while loop, and returning ' + 'outside of the loop') + + +class DetectReturnInConditional(gast.NodeVisitor): + """Assert that no return statements are present in conditionals.""" + + def __init__(self): + self.cant_return = False + super(DetectReturnInConditional, self).__init__() + + def visit_If(self, node): + self.cant_return = True + self.generic_visit(node) + self.cant_return = False + + def visit_Return(self, node): + if self.cant_return: + raise ValueError( + 'After transforms, a conditional contained a `return `statement, ' + 'which is not allowed. This is a bug, and should not happen.') + + +class DetectReturnInFunctionDef(gast.NodeVisitor): + + def visit_FunctionDef(self, node): + self.generic_visit(node) + if not contains_return(node): + raise ValueError( + 'Each function definition should contain at least one return.') + + +def transform(node, context): + """Ensure a function has only a single return. + + This transforms an AST node with multiple returns successively into containing + only a single return node. + There are a few restrictions on what we can handle: + - An AST being transformed must contain at least one return. + - No returns allowed in loops. We have to know the type of the return value, + and we currently don't have either a type inference system to discover it, + nor do we have a mechanism for late type binding in TensorFlow. + - After all transformations are finished, a Return node is not allowed inside + control flow. If we were unable to move a return outside of control flow, + this is an error. + + Args: + node: an AST node to transform + context: a context object + + Returns: + new_node: an AST with a single return value + + Raises: + ValueError: if the AST is structured so that we can't perform the + transform. + """ + # Make sure that the function has at least one return statement + # TODO(alexbw): turning off this assertion for now -- + # we need to not require this in e.g. class constructors. + # DetectReturnInFunctionDef().visit(node) + + # Make sure there's no returns in unsupported locations (loops, try/except) + DetectReturnInUnsupportedControlFlow().visit(node) + + while True: + + # Try to lift all returns out of if statements and with blocks + lr = LiftReturn(context) + node = lr.visit(node) + changes_made = lr.changes_made + fe = FoldElse(context) + node = fe.visit(node) + changes_made = changes_made or fe.changes_made + + if not changes_made: + break + + # Make sure we've scrubbed all returns from conditionals + DetectReturnInConditional().visit(node) + + return node diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d483005a09537ea8227814f65aa7e6402c853f60 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -0,0 +1,189 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for single_return module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.python.framework.ops import name_scope +from tensorflow.python.platform import test + + +class SingleReturnTest(converter_test_base.TestCase): + + def compiled_fn(self, test_fn, *args): + node = self.parse_and_analyze(test_fn, {}) + node = single_return.transform(node, self.ctx) + module = self.compiled(node, *args) + return module + + def test_noop(self): + # Noop + def test_fn(x): + return x + + with self.compiled_fn(test_fn) as result: + self.assertEqual(test_fn(2.0), result.test_fn(2.0)) + + def test_return_expression(self): + # ANF + def test_fn(x): + return x * x + + with self.compiled_fn(test_fn) as result: + x = 2 + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_merge(self): + # Simple merge + def test_fn(x): + if x > 0: + return x + else: + return x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_orphan_branch(self): + + def test_fn(x): + if x > 0: + return x + + with self.assertRaises(ValueError): + self.compiled_fn(test_fn) + + def test_lift_body_into_false_branch(self): + + def test_fn(x): + if x > 0: + return x + return x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_lift_body_into_true_branch(self): + + def test_fn(x): + if x < 0: + x *= x + else: + # TODO(alexbw): linter bug here that requires us suppress this warning. + return x # pylint: disable=undefined-loop-variable + return x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_nested_if(self): + + def test_fn(x): + if x > 0: + if x < 5: + return x + else: + return x * x + else: + return x * x * x + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2, 5]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_context_manager(self): + + def test_fn(x): + + with name_scope(''): + return x * x + + with self.compiled_fn(test_fn) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_context_manager_in_conditional(self): + + def test_fn(x): + if x > 0: + with name_scope(''): + return x * x + else: + return x + + with self.compiled_fn(test_fn, name_scope) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def text_conditional_in_context_manager(self): + + def test_fn(x): + with name_scope(''): + if x > 0: + return x * x + else: + return x + + with self.compiled_fn(test_fn) as result: + result.name_scope = name_scope + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_no_return(self): + + def test_fn(x): + x *= x + + with self.compiled_fn(test_fn) as result: + self.assertEqual(test_fn(2), result.test_fn(2)) + + def test_nested_functiondefs(self): + + def test_fn(x): + + def inner_fn(y): + if y > 0: + return y * y + else: + return y + + return inner_fn(x) + + with self.compiled_fn(test_fn) as result: + for x in [-2, 2]: + self.assertEqual(test_fn(x), result.test_fn(x)) + + def test_loop(self): + + def test_fn(x): + for _ in range(10): + return x + return x + + with self.assertRaises(ValueError): + self.compiled_fn(test_fn) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d62390494b78c415212ba91ac914cdfee324f971 --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -0,0 +1,1919 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Dev Summit 2018 - Autograph", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python2", + "display_name": "Python 2" + } + }, + "cells": [ + { + "metadata": { + "id": "g7nGs4mzVUHP", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Experimental: TF Autograph\n", + "**TensorFlow Dev Summit, 2018.**\n", + "\n", + "This interactive notebook demonstrates **autograph**, an experimental source-code transformation library to automatically convert TF.Eager and Python code to TensorFlow graphs.\n", + "\n", + "**Note: this is pre-alpha software!** The notebook works best with Python 2, for now.\n", + "\n", + "> ![alt text](https://lh3.googleusercontent.com/QOvy0clmg7siaVKzwmSPAjicWWNQ0OeyaB16plDjSJMf35WD3vLjF6mz4CGrhSHw60HnlZPJjkyDCBzw5XOI0oBGSewyYw=s688)\n", + "\n", + "### Table of Contents\n", + "1. _Write Eager code that is fast and scalable._\n", + "2. _Case study: complex control flow._\n", + "3. _Case study: training MNIST with Keras._\n", + "4. _Case study: building an RNN._" + ] + }, + { + "metadata": { + "id": "uFcgBENZqkB2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Install TensorFlow; note that Colab notebooks run remotely, on virtual\n", + "# instances provided by Google.\n", + "!pip install -U -q tf-nightly" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Pa2qpEmoVOGe", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import os\n", + "import time\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import six\n", + "\n", + "from google.colab import widgets" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "ZVKfj5ttVkqz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 1. Write Eager code that is fast and scalable\n", + "\n", + "TF.Eager gives you more flexibility while coding, but at the cost of losing the benefits of TensorFlow graphs. For example, Eager does not currently support distributed training, exporting models, and a variety of memory and computation optimizations.\n", + "\n", + "Autograph gives you the best of both worlds: write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported." + ] + }, + { + "metadata": { + "id": "snaZRFdWd9ym", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "For example, autograph can convert a function like this:" + ] + }, + { + "metadata": { + "id": "9__n8cSIeDnD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def g(x):\n", + " if x > 0:\n", + " x = x * x\n", + " else:\n", + " x = 0\n", + " return x" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "gq0eQcuReHET", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "... into a TF graph-building function:" + ] + }, + { + "metadata": { + "id": "sELSn599ePUF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 413 + }, + "outputId": "bb0c7216-1ca3-4da1-d1fb-589902cdcd1a", + "executionInfo": { + "status": "ok", + "timestamp": 1522345737505, + "user_tz": 240, + "elapsed": 243, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "print(autograph.to_code(g))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "from __future__ import print_function\n", + "import tensorflow as tf\n", + "from tensorflow.contrib.autograph.impl import api as autograph_api\n", + "from tensorflow.contrib.autograph import utils as autograph_utils\n", + "\n", + "def tf__g(x):\n", + " with tf.name_scope('g'):\n", + "\n", + " def if_true():\n", + " with tf.name_scope('if_true'):\n", + " x_1, = x,\n", + " x_1 = x_1 * x_1\n", + " return x_1,\n", + "\n", + " def if_false():\n", + " with tf.name_scope('if_false'):\n", + " x_1, = x,\n", + " x_1 = 0\n", + " return x_1,\n", + " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n", + " return x\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "j74n-8hEe6dk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "You can then use the converted function as you would any regular TF op -- you can pass `Tensor` arguments and it will return `Tensor`s:" + ] + }, + { + "metadata": { + "id": "AkVaY0-dfEbH", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "4ffe3757-c44d-424c-c2a8-7ddc973bfcce", + "executionInfo": { + "status": "ok", + "timestamp": 1522345737841, + "user_tz": 240, + "elapsed": 257, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "tf_g = autograph.to_graph(g)\n", + "\n", + "with tf.Graph().as_default(): \n", + "\n", + " g_ops = tf_g(tf.constant(9))\n", + "\n", + " with tf.Session() as sess:\n", + " tf_g_result = sess.run(g_ops)\n", + "\n", + " print('g(9) = %s' % g(9))\n", + " print('tf_g(9) = %s' % tf_g_result)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "g(9) = 81\n", + "tf_g(9) = 81\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "trrHQBM1VnD0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 2. Case study: complex control flow\n", + "\n", + "Autograph can convert a large chunk of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in autograph.\n", + "Autograph will automatically convert most Python control flow statements into their correct graph equivalent.\n", + " " + ] + }, + { + "metadata": { + "id": "u0YG3DPgZxoW", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:" + ] + }, + { + "metadata": { + "id": "xJYDzOcrZ8pI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "6c244ee4-b141-4ad6-eefa-cfffa71f33c6", + "executionInfo": { + "status": "ok", + "timestamp": 1522345738402, + "user_tz": 240, + "elapsed": 483, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def sum_even(numbers):\n", + " s = 0\n", + " for n in numbers:\n", + " if n % 2 > 0:\n", + " continue\n", + " s += n\n", + " return s\n", + "\n", + "\n", + "tf_sum_even = autograph.to_graph(sum_even)\n", + "\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " result = sess.run(tf_sum_even(tf.constant([10, 12, 15, 20])))\n", + "\n", + " print('Sum of even numbers: %s' % result)\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(sum_even))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Sum of even numbers: 42\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "_YXo4KOcbKrn", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Try replacing the `continue` in the above code with `break` -- Autograph supports that as well!" + ] + }, + { + "metadata": { + "id": "xHmC0rBIavW_", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The Python code above is much more readable than the matching graph code. Autograph takes care of tediously converting every piece of Python code into the matching TensorFlow graph version for you, so that you can quickly write maintainable code, but still benefit from the optimizations and deployment benefits of graphs." + ] + }, + { + "metadata": { + "id": "UEHWGpBXbS7g", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. " + ] + }, + { + "metadata": { + "id": "qUU57xlEbauI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "outputId": "add3db4a-2077-4dd5-f7a7-a5b5a4529c26", + "executionInfo": { + "status": "ok", + "timestamp": 1522345738697, + "user_tz": 240, + "elapsed": 253, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def f(x):\n", + " assert x != 0, 'Do not pass zero!'\n", + " return x * x\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " try:\n", + " print(sess.run(tf_f(tf.constant(0))))\n", + " except tf.errors.InvalidArgumentError as e:\n", + " print('Got error message: %s' % e.message)\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(f))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Got error message: assertion failed: [Do not pass zero!]\n", + "\t [[Node: f/Assert/Assert = Assert[T=[DT_STRING], summarize=3, _device=\"/job:localhost/replica:0/task:0/device:CPU:0\"](f/NotEqual, f/Assert/Assert/data_0)]]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "w5hBZaVJbck4", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "You can also use `print` functions in-graph:" + ] + }, + { + "metadata": { + "id": "6NdzRKLEboRv", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "fb82dfc3-790f-4127-87f6-361805be9e9b", + "executionInfo": { + "status": "ok", + "timestamp": 1522345739013, + "user_tz": 240, + "elapsed": 247, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def print_sign(n):\n", + " if n >= 0:\n", + " print(n, 'is positive!')\n", + " else:\n", + " print(n, 'is negative!')\n", + " return n\n", + "\n", + "\n", + "tf_print_sign = autograph.to_graph(print_sign)\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " sess.run(tf_print_sign(tf.constant(1)))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(print_sign))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "1 is positive!\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "9u_Z3i3AivLA", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "We can convert lists to TensorArray, so appending to lists also works, with a few modifications:" + ] + }, + { + "metadata": { + "id": "MjhCQJVuiTNR", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "outputId": "dc320b87-595b-4392-d29c-994486fd8a0a", + "executionInfo": { + "status": "ok", + "timestamp": 1522345744470, + "user_tz": 240, + "elapsed": 5391, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def f(n):\n", + " numbers = []\n", + " # We ask you to tell us about the element dtype.\n", + " autograph.utils.set_element_type(numbers, tf.int32)\n", + " for i in range(n):\n", + " numbers.append(i)\n", + " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + "\n", + "\n", + "tf_f = autograph.to_graph(f)\n", + "with tf.Graph().as_default():\n", + " with tf.Session() as sess:\n", + " print(sess.run(tf_f(tf.constant(5))))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(f))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "[0 1 2 3 4]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "UdG8ZFrkTAF2", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "And all of these functionalities, and more, can be composed into more complicated code:\n" + ] + }, + { + "metadata": { + "id": "DVs6wt8NKaGQ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {} + ], + "base_uri": "https://localhost:8080/", + "height": 53 + }, + "cellView": "code", + "outputId": "0a4b8d08-8f65-4bbc-85ba-dc4c60563519", + "executionInfo": { + "status": "ok", + "timestamp": 1522345745186, + "user_tz": 240, + "elapsed": 658, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def print_primes(n):\n", + " \"\"\"Returns all the prime numbers less than n.\"\"\"\n", + " assert n > 0\n", + " \n", + " primes = []\n", + " autograph.utils.set_element_type(primes, tf.int32)\n", + " for i in range(2, n):\n", + " is_prime = True\n", + " for k in range(2, i):\n", + " if i % k == 0:\n", + " is_prime = False\n", + " break\n", + " if not is_prime:\n", + " continue\n", + " primes.append(i)\n", + " all_primes = primes.stack()\n", + "\n", + " print('The prime numbers less than', n, 'are:')\n", + " print(all_primes)\n", + " return tf.no_op()\n", + "\n", + " \n", + "tf_print_primes = autograph.to_graph(print_primes)\n", + "with tf.Graph().as_default(): \n", + " with tf.Session() as sess:\n", + " n = tf.constant(50)\n", + " sess.run(tf_print_primes(n))\n", + " \n", + "# Uncomment the line below to print the generated graph code\n", + "# print(autograph.to_code(print_primes))" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "The prime numbers less than 50 are:\n", + "[ 2 3 5 7 11 13 17 19 23 29 31 37 41 43 47]\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "JQ8kQT99VqDk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 3. Case study: training MNIST with Keras\n", + "\n", + "As we've seen, writing control flow in Autograph is easy. So running a training loop in graph should be easy as well!\n", + "\n", + "Here, we show an example of such a training loop for a simple Keras model that trains on MNIST." + ] + }, + { + "metadata": { + "id": "0CrtGWgwuLJr", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "import gzip\n", + "import shutil\n", + "\n", + "from six.moves import urllib\n", + "\n", + "\n", + "def download(directory, filename):\n", + " filepath = os.path.join(directory, filename)\n", + " if tf.gfile.Exists(filepath):\n", + " return filepath\n", + " if not tf.gfile.Exists(directory):\n", + " tf.gfile.MakeDirs(directory)\n", + " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n", + " zipped_filepath = filepath + '.gz'\n", + " print('Downloading %s to %s' % (url, zipped_filepath))\n", + " urllib.request.urlretrieve(url, zipped_filepath)\n", + " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n", + " shutil.copyfileobj(f_in, f_out)\n", + " os.remove(zipped_filepath)\n", + " return filepath\n", + "\n", + "\n", + "def dataset(directory, images_file, labels_file):\n", + " images_file = download(directory, images_file)\n", + " labels_file = download(directory, labels_file)\n", + "\n", + " def decode_image(image):\n", + " # Normalize from [0, 255] to [0.0, 1.0]\n", + " image = tf.decode_raw(image, tf.uint8)\n", + " image = tf.cast(image, tf.float32)\n", + " image = tf.reshape(image, [784])\n", + " return image / 255.0\n", + "\n", + " def decode_label(label):\n", + " label = tf.decode_raw(label, tf.uint8)\n", + " label = tf.reshape(label, [])\n", + " return tf.to_int32(label)\n", + "\n", + " images = tf.data.FixedLengthRecordDataset(\n", + " images_file, 28 * 28, header_bytes=16).map(decode_image)\n", + " labels = tf.data.FixedLengthRecordDataset(\n", + " labels_file, 1, header_bytes=8).map(decode_label)\n", + " return tf.data.Dataset.zip((images, labels))\n", + "\n", + "\n", + "def mnist_train(directory):\n", + " return dataset(directory, 'train-images-idx3-ubyte',\n", + " 'train-labels-idx1-ubyte')\n", + "\n", + "def mnist_test(directory):\n", + " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "2zu1U9Nqir6L", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "First, we'll define a small three-layer neural network using the Keras API" + ] + }, + { + "metadata": { + "id": "x_MU13boiok2", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def mlp_model(input_shape):\n", + " model = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n", + " tf.keras.layers.Dense(100, activation='relu'),\n", + " tf.keras.layers.Dense(10, activation='softmax')])\n", + " model.build()\n", + " return model" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Wuqg3H8mi0Xj", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Let's connect the model definition (here abbreviated as `m`) to a loss function, so that we can train our model." + ] + }, + { + "metadata": { + "id": "W51sfbONiz_5", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def predict(m, x, y):\n", + " y_p = m(x)\n", + " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n", + " l = tf.reduce_mean(losses)\n", + " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n", + " accuracy = tf.reduce_mean(accuracies)\n", + " return l, accuracy" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "035tNWQki9tr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Now the final piece of the problem specification (before loading data, and clicking everything together) is backpropagating the loss through the model, and optimizing the weights using the gradient." + ] + }, + { + "metadata": { + "id": "CsAD0ajbi9iZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def fit(m, x, y, opt):\n", + " l, accuracy = predict(m, x, y)\n", + " opt.minimize(l)\n", + " return l, accuracy" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "PcVRIacKjSwb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "These are some utility functions to download data and generate batches for training" + ] + }, + { + "metadata": { + "id": "RVw57HdTjPzi", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def setup_mnist_data(is_training, hp, batch_size):\n", + " if is_training:\n", + " ds = mnist_train('/tmp/autograph_mnist_data')\n", + " ds = ds.shuffle(batch_size * 10)\n", + " else:\n", + " ds = mnist_test('/tmp/autograph_mnist_data')\n", + " ds = ds.repeat()\n", + " ds = ds.batch(batch_size)\n", + " return ds\n", + "\n", + "def get_next_batch(ds):\n", + " itr = ds.make_one_shot_iterator()\n", + " image, label = itr.get_next()\n", + " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n", + " y = tf.one_hot(tf.squeeze(label), 10)\n", + " return x, y" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "2zEJH5XNjgFz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This function specifies the main training loop. We instantiate the model (using the code above), instantiate an optimizer (here we'll use SGD with momentum, nothing too fancy), and we'll instantiate some lists to keep track of training and test loss and accuracy over time.\n", + "\n", + "In the loop inside this function, we'll grab a batch of data, apply an update to the weights of our model to improve its performance, and then record its current training loss and accuracy. Every so often, we'll log some information about training as well." + ] + }, + { + "metadata": { + "id": "UUI0566FjZPx", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(train_ds, test_ds, hp):\n", + " m = mlp_model((28 * 28,))\n", + " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", + " train_losses = []\n", + " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n", + " test_losses = []\n", + " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n", + " train_accuracies = []\n", + " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n", + " tf.float32)\n", + " test_accuracies = []\n", + " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n", + " tf.float32)\n", + " i = tf.constant(0)\n", + " while i < hp.max_steps:\n", + " train_x, train_y = get_next_batch(train_ds)\n", + " test_x, test_y = get_next_batch(test_ds)\n", + " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n", + " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n", + " if i % (hp.max_steps // 10) == 0:\n", + " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n", + " step_test_loss, 'train accuracy:', step_train_accuracy,\n", + " 'test accuracy:', step_test_accuracy)\n", + " train_losses.append(step_train_loss)\n", + " test_losses.append(step_test_loss)\n", + " train_accuracies.append(step_train_accuracy)\n", + " test_accuracies.append(step_test_accuracy)\n", + " i += 1\n", + " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", + " test_accuracies.stack())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "cYiUQ1ppkHzk", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Everything is ready to go, let's train the model and plot its performance!" + ] + }, + { + "metadata": { + "id": "K1m8TwOKjdNd", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {}, + {}, + {} + ], + "base_uri": "https://localhost:8080/", + "height": 988 + }, + "outputId": "f9d3eef3-5bea-45c1-ddf9-4edee73e4436", + "executionInfo": { + "status": "ok", + "timestamp": 1522345800262, + "user_tz": 240, + "elapsed": 52391, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "with tf.Graph().as_default():\n", + " hp = tf.contrib.training.HParams(\n", + " learning_rate=0.05,\n", + " max_steps=500,\n", + " )\n", + " train_ds = setup_mnist_data(True, hp, 50)\n", + " test_ds = setup_mnist_data(False, hp, 1000)\n", + " tf_train = autograph.to_graph(train)\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = tf_train(train_ds, test_ds, hp)\n", + "\n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " (train_losses, test_losses, train_accuracies,\n", + " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n", + " test_accuracies])\n", + " plt.title('MNIST train/test losses')\n", + " plt.plot(train_losses, label='train loss')\n", + " plt.plot(test_losses, label='test loss')\n", + " plt.legend()\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Loss')\n", + " plt.show()\n", + " plt.title('MNIST train/test accuracies')\n", + " plt.plot(train_accuracies, label='train accuracy')\n", + " plt.plot(test_accuracies, label='test accuracy')\n", + " plt.legend(loc='lower right')\n", + " plt.xlabel('Training step')\n", + " plt.ylabel('Accuracy')\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/train-images-idx3-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/train-labels-idx1-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/autograph_mnist_data/t10k-images-idx3-ubyte.gz\n", + "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/autograph_mnist_data/t10k-labels-idx1-ubyte.gz\n", + "Step 0 train loss: 2.244329 test loss: 2.2499208 train accuracy: 0.12 test accuracy: 0.161\n", + "Step 50 train loss: 0.64771986 test loss: 0.56013924 train accuracy: 0.82 test accuracy: 0.836\n", + "Step 100 train loss: 0.49011207 test loss: 0.42143965 train accuracy: 0.84 test accuracy: 0.879\n", + "Step 150 train loss: 0.3768609 test loss: 0.39319593 train accuracy: 0.88 test accuracy: 0.883\n", + "Step 200 train loss: 0.36007702 test loss: 0.37089333 train accuracy: 0.9 test accuracy: 0.881\n", + "Step 250 train loss: 0.182115 test loss: 0.28543878 train accuracy: 0.94 test accuracy: 0.915\n", + "Step 300 train loss: 0.2119576 test loss: 0.22305593 train accuracy: 0.92 test accuracy: 0.93\n", + "Step 350 train loss: 0.12932214 test loss: 0.29057172 train accuracy: 0.96 test accuracy: 0.906\n", + "Step 400 train loss: 0.22937602 test loss: 0.2200287 train accuracy: 0.92 test accuracy: 0.925\n", + "Step 450 train loss: 0.23444137 test loss: 0.19857481 train accuracy: 0.94 test accuracy: 0.94\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3XmAFNW9Pvynlt5mYdhmQMHggnGN\nS9zCD0ElKug1edUY9ZoQTYze3GuiRk1uYjRqRHNj4n5NrhKjiUYlbihGQFRUFDSoKIvgICAO6+xL\n711V5/2jlq7qZaZnpnumZ3g+/zjTXV1dXSP91PecU+dIQggBIiIiGjLkwT4AIiIi6h2GNxER0RDD\n8CYiIhpiGN5ERERDDMObiIhoiGF4ExERDTEMb6JeOOigg3DllVdmPf6rX/0KBx10kGe766+/3rPN\ne++9h9mzZwMAtm3bhkMPPdR57osvvsCPfvQjzJw5EzNnzsTZZ5+NV199FQBw0003YdasWZg1axYO\nO+wwnHLKKc7v4XDY8x7JZBLz58/v9edavXo1Lr300oK2XbBgAebMmdPn97J19/rZs2fjhRde6PO+\niYY7hjdRL3366aee0Ewmk1izZk3WditXrsQnn3xS0D6vu+46TJs2DYsXL8bixYtxyy234LrrrsPO\nnTtxyy23YNGiRVi0aBHGjRuH3//+987vVVVVnv188sknfQrUI444Ag8//HBB2y5fvhxTpkzp83vZ\n+vt6oj0Zw5uol0444QQsWbLE+f3tt9/GV77ylaztrrnmGtx+++0F7bO+vh5HHnmk8/uRRx6JxYsX\nY/z48QUfV3NzM3784x/jo48+wkUXXQTAbAF48MEHMXPmTOi6jlWrVuHcc8/FrFmzcOaZZ2L58uUA\nzFaB0047DQBw//334ze/+Q2uuOIKfP3rX8d5552HxsZG533ee+89HHzwwVnv9cEHH+Bb3/oWTjvt\nNJx//vloaGgAAOzevRsXX3wxzjzzTJx66qm4++67cx5rPu+99x7OOecczJo1C9/+9redC6Vc++3u\ncSEE/vd//xczZ87EKaecgjlz5kDXdQDAwoULcdZZZ+GMM87AN77xDbz33nsFn3eiwcDwJuqlM844\nAy+99JLz+z//+U/MmjUr53ZCCCxatKjHfU6fPh1XXnkl/va3v2HTpk0AgHHjxkGSpIKPa+zYsbjm\nmmtw1FFH4YknnnAeF0Jg8eLFUBQFv/71r3HppZdi0aJFuPzyy3HTTTfl3NeiRYtw/fXX49VXX8WY\nMWPw7LPPAgA2bdqE2tpaTJgwwfNe4XAY//mf/4lrrrkGS5Yswfe+9z1cddVVAIBHH30Uxx13HF5+\n+WUsWLAADQ0NMAwj57FmikQiuOqqq3DDDTdg0aJF+OEPf4jrrrsOhmHk3G9jY2Pex1944QUsWrQI\nzzzzDJYsWYKGhgY8+eSTAIBbbrkFDz74IBYuXIibbroJr7/+esHnnWgwMLyJeun444/Hxo0b0dLS\nglgshlWrVmHKlCk5t73++uvxhz/8AYlEott9/v73v8d3vvMdLFiwAGeddRZmzJjhBEt/nXzyyc7P\n8+fPxxlnnAEAOOaYY5zqONOxxx6LCRMmQJIkHHLIIdi5cycAYMWKFTk/6wcffIBx48Zh6tSpAICz\nzjoLX3zxBXbs2IExY8bg7bffxvvvvw+/34+77roLdXV1BR376tWrMX78eBxzzDEAgJkzZ6KtrQ3b\nt2/Pu998jy9duhTf+ta3UF1dDVVV8e1vfxuvvPIKAGDMmDF46qmnsH37dhx77LH45S9/WdjJJRok\n6mAfANFQoygKTj/9dCxcuBCjR4/GiSeeCFXN/U/psMMOw3HHHYdHHnkERx99dN59BgIBXHrppbj0\n0kvR2dmJRYsW4fbbb8fEiRMxbdq0fh3vyJEjnZ8XLFiAv/3tb4hEIjAMA/mWNqiurnZ+VhTFaV5+\n5513cMkll2Rt39nZiYaGBk8LhN/vR2trKy655BIYhoFbbrkFjY2N+M53voOf/OQnBR17a2srRowY\nkXVsLS0tefeb7/Guri48/PDDmDdvHgBA13WMHj0aAPCnP/0Jf/rTn3Duuedir732wvXXX4/jjz++\noGMkGgwMb6I+OPPMM3H33Xdj1KhRPfbZ/vSnP8W5556LiRMn5ny+tbUV69evd6rWESNG4Pzzz8ey\nZctQX1/f7/C27d69GzfccAOefvppHHLIIfj8888xc+bMgl+vaRrWrFmT8yKkrq4O+++/P5577rmc\nr7388stx+eWXY8uWLbjsssucSronY8aMQXt7u/O7EAIdHR0YM2YMVFXNud+pU6fmfLyurg4zZszA\nd7/73az3+dKXvoTf/va3MAwD8+fPx7XXXotly5YVeGaIBh6bzYn64Oijj0ZjYyM2btzYY4VWV1eH\n73znO7j//vtzPh+Px3HllVd6wmLr1q34+OOPceyxx/bquFRVRTgczllRt7a2oqKiAvvvvz80TXMq\n0EgkUtC+V69ejYMOOgh+vz/rvY488kg0NTXh448/BgA0NDTgZz/7GYQQ+PWvf4133nkHgBmSY8eO\nhSRJ3R6r7YgjjkBzczNWrVoFwBxfMH78eEycODHvfvM9/vWvfx0vvPACYrEYAOCpp57C888/j9bW\nVnz/+99HOByGLMs48sgjezXWgGgwsPIm6gNJknDaaachFotBlnu+Bv7BD36Ap59+Oudze++9N/70\npz/hvvvuw5w5cyCEQFVVFX75y196RqAX4phjjsEf/vAHTJs2DW+++abnuYMPPhjTp0/HzJkzMWbM\nGPziF7/Ahx9+iNmzZ+O///u/e9y3fYtYvve67777cOuttyISicDn8+Gqq66CJEm48MIL8etf/xq3\n3norhBCYMWMGpkyZgh07dnheryhK1ntWVFTgnnvuwa233opoNIrRo0fjrrvu6na/I0eOzPk4AGzc\nuBHnnHMOADPYb7vtNowePRrTpk3Dt771LSiKAp/Ph9tuu61X551ooElcz5uIiGhoYbM5ERHREMPw\nJiIiGmIY3kREREMMw5uIiGiIYXgTERENMUPmVrGmpq6i7m/UqAq0tUWLus89Ec9j//Ec9h/PYXHw\nPPZfsc9hbW11zsf32MpbVbPvKaXe43nsP57D/uM5LA6ex/4bqHO4x4Y3ERHRUMXwJiIiGmIY3kRE\nREMMw5uIiGiIYXgTERENMQxvIiKiIYbhTURENMQwvImIaNh6443XCt723nvvxI4d23vc7sMP38cN\nN/y8P4fVbwxvIiIalnbu3IFXX11c8PZXXXUt9t57QgmPqHiGzPSoREREvXHXXb/D+vXr8Mgjc2EY\nBnbs2I6dO3fgnnv+iN/+9jdoampELBbDD35wOaZOnYYf//hyXHPNz7F06WuIRML44out2L59G668\n8lpMmTI153u89toSzJv3dyiKgoMOOgS33XYL6us34M47fwefzwe/349bbvktdu7cnvVYdXXuqU8L\nsceGd0c4gfc3NOLYg+sG+1CIiIa9f7z+GVZuaCzqPo87uA7nz5ic9/l///fZeO65f+D7378MDz/8\nIDQthT/+8c9oa2vF8cd/DWeccRa2b9+GG2/8BaZOneZ5bWPjbvzhD/fh3XeX44UXns0Z3tFoFA89\n9AAeeeQJVFRU4Oc//yneffddvPzyyzjnnPMwa9a/4YMPVqK1tQUvv7wg6zGGdx9ceecbaO2M46ZL\njsOk8X0/gURENDQccshhAIDq6hFYv34dXnzxOUiSjM7OjqxtjzjiKABAXV0dwuFwzv01NHyBiRO/\nhIqKCgDA0Ucfg/Xr1+PEE0/CH/7wP2ho+AJf//ppmDRp35yP9cceGd5b23YiPOFNSMnD0dwRZ3gT\nEZXY+TMmd1slDwSfzwcAWLJkETo7O/HAA39GZ2cnfvjD2VnbKkp6gREhRM79SZL3OU1LQZJCOPbY\n4/HnP/8Ny5cvw5w5N+PHP74652Nf/eqxff4se2R4f7ztCyjVbTBG70RLZ3ywD4eIiEpAlmXoup71\neHt7O/baa2/Isow333wdqVSqT/vfZ59J2LbtC0SjEVRUVGLVqg9x1VU/xrPPzsOUKSfi9NPPgBAC\n9fUbsGXLpqzHGN69dPykA7G4CZArO9DSwfAmIhqOJk3aD59+ugH33XcnKiurnMdPPnkGfvGLa/DJ\nJ2vxb//2TdTV1eGRR+b2ev+hUAhXXHEVrr32J5AkGUcccRSOPfZY7NzZghtv/AWqqqrg8/lw/fU3\nob7+06zH+kMS+doDykxTU1dR93fjit+ipTOCQyLn4yfnHlHUfe9Jamuri/632dPwHPYfz2Fx8Dz2\nX7HPYW1t7m7dPfY+7y+P2Q+SL4mmcOtgHwoREVGv7LHhPbFmPACgLdk2yEdCRETUO3tseI8JjQIA\nxBFGPKkN8tEQEREVbs8N74rRAADJH+egNSIiGlL22PAeW2FW3pI/xtvFiIhoSNljw3uME96svImI\naGjZY8M75AvCLwcg+eNoZuVNRDQs9WZJUNtHH32ItjbvnUjlsAyo2x4b3gAwMlDDypuIaJjq7ZKg\ntn/+88Ws8C43e+QMa7a6ijFojDWiqSt7UnoiIhra3EuCXnDBRbj99lvQ1dUFXddx9dU/w+TJB+Lx\nxx/Fm28uhSzLmDp1Gg455FAsW/YGtmzZjDlz7sD48eOz9pu5DOjVV1/nLANaWRkCIJdkGVC3PTy8\nxwItQKfWPtiHQkQ0rD332UtY1bimqPs8uu4rOHfyWXmfdy8J+uijf8YJJ/w/fOMbZ2PLls24994/\n4J57/oinnnoc8+cvgqIomD//WRx33NcwefKXcc01P88Z3LmWAf3ww/fx1ltLcc4552H27AuxaNHr\nJVkG1G2PDu/a0FgAQAysvImIhrM1a1ajvb0Nixe/DABIJMzu0pNP/jquvvq/cNpps3D66bN63E+u\nZUDr6zc4S362tOzClCknlWQZULc9OrzrKszwTildMISALEmDfERERMPTuZPP6rZKLjWfT8VPf/oz\nHH64dy2L6677JbZu/Ryvv74EP/nJf+Chh/7a7X5yLQMaCAScJT/XrFlZsmVA3fboAWt25Y1gFNE4\nZ1kjIhpO3EuCHnro4XjrrTcAAFu2bMZTTz2OcDiMRx6Zi0mT9sX3v38ZqqtrEI1G8i4lCniXAQWA\nVas+xEEHHYpnn52Hzs4OfPOb38QFF1yE+voNzmOnn36G81ix7NGV96hgDSQhQw5EEYmnUBXyDfYh\nERFRkbiXBP3hD3+E2267Gf/1Xz+EYRi4+urrUFVVhfb2Nlx22fcQClXg8MOPwIgRNTjqqK/ihhv+\nG7/97Z3Yf/8DPPvMtQzokUcehVgsihtv/AVGjaoBIJdkGVC3PXZJUHvZtp++fgvicQM/O+pa7L/3\niKK+x56ASwj2H89h//EcFgfPY/9xSdABEpCCkNQUIvHUYB8KERFRQfb48A4qIUiqhs4oJ2ohIqKh\nYY8P7wo1BABoj4YH+UiIiIgKs8eHd6XPvFevIxEZ5CMhIiIqzB4f3iMClQCAjjjDm4iIhoY9PrxH\nVZgj+Xa1tw3ykRARERVmjw/v0RXm7WE7OjrQHk4M8tEQERH1bI8P70qfOWBNUpNYvallkI+GiIio\nZwxvn9nnDSWFpvbY4B4MERFRAUo6Peodd9yBDz74AJqm4T/+4z9w+umnO88tX74cd911FxRFwfTp\n03HFFVeU8lDysm8Vk9QUWjvZbE5EROWvZOH97rvvYuPGjZg3bx7a2tpwzjnneMJ7zpw5ePjhhzFu\n3Dh897vfxcyZMzF58uRSHU5eITVo/qBoaOviRC1ERFT+Shbexx13HI44wlx6bcSIEYjFYtB1HYqi\noKGhATU1Ndhrr70AACeddBJWrFgxKOHtV/wAAJ9foK2NlTcREZW/koW3oijOYuXPPPMMpk+fDkVR\nAABNTU0YPXq0s+3o0aPR0NDQ7f5GjaqAqipFPcba2mqM1M3K2+8XaI8kMXZsFSSu690r+SbOp8Lx\nHPYfz2Fx8Dz230Ccw5IvCfrqq6/imWeewV/+8pd+7aetLVqkIzLZK78IISBLMiTFQCKpY+u2NlQG\nuTRoobgKUf/xHPYfz2Fx8Dz237BYVWzZsmX4v//7P8ydOxfV1ekDqKurQ3Nzs/P77t27UVdXV8pD\nyUuSJPhlP2TVXHi9jYPWiIiozJUsvLu6unDHHXfgwQcfxMiRIz3PTZw4EeFwGNu2bYOmaVi6dCmm\nTp1aqkPpkV/xAbIZ3h2R5KAdBxERUSFK1mz+8ssvo62tDVdffbXz2AknnICDDjoIp512Gm6++WZc\ne+21AIAzzzwT++23X6kOpUd+xY9kyhxpHo5xXW8iIipvJQvvCy64ABdccEHe54877jjMmzevVG/f\nKwHFjw6YS4IyvImIqNzt8TOsAYBf9kMXZmhHGN5ERFTmGN4w+7wNGIBksPImIqKyx/BGeqIWyDrC\ncYY3ERGVN4Y3zD5vAGZ4s/ImIqIyx/CG2ecNAKrPYJ83ERGVPYY3rPu8AYRCEitvIiIqewxvpPu8\nQyEgHNMG+WiIiIi6x/BGus87GABiCQ26YQzyEREREeXH8Ea68g5YS3tH46y+iYiofDG8Afhls89b\nVc2KO57UB/NwiIiIusXwRrryllUBwGw6JyIiKlcMbwABJQAAzrKgrLyJiKicMbwBhFQzvCXVrLjj\nSVbeRERUvhjeAIKKNVJNNu/xjiVYeRMRUflieAMIWpW3IZkVd4yVNxERlTGGN4CQGgIAGJJZecdZ\neRMRURljeAMIWgPWdCQBsM+biIjKG8MbgCqrUCQFmhXe7PMmIqJyxvAGIEkSgmoAKWGFNytvIiIq\nYwxvS1AJImkkAABxTtJCRERljOFtCaoBJHQrvDlJCxERlTGGtyWkBpHQk1BkNpsTEVF5Y3hbgkoQ\nAgKBoOCtYkREVNYY3hZ7opZgCIiyz5uIiMoYw9sSVM0pUitCQCSWGuSjISIiyo/hbQlZ85sHQwJJ\nzUAixaZzIiIqTwxvi115B4IGAFbfRERUvhjeFrvP2+c3wzvM8CYiojLF8LbYzeYqw5uIiMocw9ti\nV96yz+zrZngTEVG5YnhbglblLavmbWIMbyIiKlcMb4tdeUNheBMRUXljeFtC1mhzQzJDm+FNRETl\niuFtCTK8iYhoiGB4W+w+b3tNb85vTkRE5YrhbfHJKmRJdtb0TunGIB8RERFRbgxviyRJCClBZ01v\nneFNRERliuHtElQDiGlxKLLEypuIiMoWw9slqAYR1xJQFRmaJgb7cIiIiHJieLsErWZzRQE0Vt5E\nRFSmGN4uITUAAQHVLxjeRERUthjeLj7Fb/5XNRjeRERUthjeLn7ZBwCQVYGUzj5vIiIqTwxvF5+s\nAgAU1YCm9b/ybutK4MEX16G5I9bvfREREdkY3i4+xay8FaU4fd5PvFqP9z7Zjb8u3NDvfREREdkY\n3i4+u9ncZ0ArQrN5PKl7/ktERFQMDG8Xu89bUQwYQsAw2O9NRETlh+HtYjebS4rZZM5Z1oiIqBwx\nvF2c0eayGdq8XYyIiMoRw9vF7vO2K+9i9HsTEREVG8PbxWk2tyvvItwuRkREVGwlDe/6+nqceuqp\nePzxx7OemzFjBi666CLMnj0bs2fPxu7du0t5KAWxK2/I5ujwfjebC1buRERUfGqpdhyNRnHrrbdi\nypQpebeZO3cuKisrS3UIvebPCG8OWCMionJUssrb7/dj7ty5qKurK9VbFF1Ws3mxwlsqzm6IiIiA\nElbeqqpCVbvf/U033YTt27fjmGOOwbXXXgtJGtyUs6dHFZLdbM5mbyIiKj8lC++eXHnllZg2bRpq\nampwxRVXYPHixZg1a1be7UeNqoCqKkU9htraas/vcf9IAIBqLi6Gqqpg1ja94fObp9enKv3aT7kb\nzp9toPAc9h/PYXHwPPbfQJzDQQvvs88+2/l5+vTpqK+v7za829qiRX3/2tpqNDV1eR4Lx1IAgJSW\nBAA0t4TRVBPo83ukkpq1Pz3rvYaLXOeReofnsP94DouD57H/in0O810IDMqtYl1dXbj00kuRTJoh\nuXLlShx44IGDcSge9mhzQ+KANSIiKl8lq7zXrl2L3/3ud9i+fTtUVcXixYsxY8YMTJw4Eaeddhqm\nT5+OCy64AIFAAIceemi3VfdA8St2n7dZMevs8yYiojJUsvA+/PDD8dhjj+V9/uKLL8bFF19cqrfv\nE6fyBitvIiIqX5xhzUWRFEiQYMCsvDnDGhERlSOGt4skSfApPqfy7uk+7x3hXXjsk38grsUH4vCI\niIgADOJo83Lll33QhTVKvIc+7/s+eghdyTDGVdTi9H1PGYjDIyIiYuWdKagEkDQSAAC9m8p7W2MY\nXckwACBpJAfk2IiIiACGd5bairGIGREEv/oqtic3593ulfcbnJ8lzn9KREQDiOGdYXyFORe7pGpY\nrb2af0N3i/ogT+tKRER7FoZ3hnGV6YVUVPjzbifAe8CJiGhwMLwzjK+oTf8iCquoZTabExHRAGJ4\nZxhfOc75OYEINEPLvaGn8GZ4ExHRwGF4Z6j2V+EHX/4h9I4xgCTQGm/r8TXs8iYiooHE8M5h/5pJ\nMLpGAQCaYq05t/GMV8tTebNXnIiISoHhnYOqSBApc7BaLJV7KVLhSmbeKkZERAOJ4Z2DqsiAYU4+\nl8g7AYsnvYmIiAYMwzsHVZEhDAUAkNBzh3chzeZERESlwPDOQVUkQDfDO5knvHuD4U5ERMXE8M5B\nkiQo1pot21o6cm8kvNsTERENFIZ3Hgp8AICV9TuwsyWS9TxHkhMR0WBheOdhhzdkHZ2R7pvO2SxO\nREQDieGdhyqlwzsX4bpXzBD5lw4lIiIqNoZ3HnZ4S0qe6VFd3EFORERUagzvPHyKz5yIRdaR1Lqv\nrA2w8iYiooHD8M5DlRXAUCApOpKp7KZzd7HNZnMiIhpIDO88fKp1r7esI5nqofJmszkREQ0ghnce\n5ixrKiRFQ0LLUXm7f2blTUREA4jhnYeqyAVX3nqe8GZBTkREpcDwzkOWJXN+c1lHIpljxLn7VjEO\nWCMiogHE8M7DMIQ5YE0WSGiprOe9zeY9lNicw4WIiIqI4Z2HYQhAN+c3j2mJ7rdlnzcREQ0ghnce\nuiGcZUFjqXj2Bp5bxdi5TUREA6eg8F67di2WLl0KALj77rtx8cUX4/333y/pgQ023RAQyQAAIKKH\nu92Wfd5ERDSQCgrvOXPmYL/99sP777+PNWvW4MYbb8R9991X6mMbVIYhIBIhAEBMdGU9z1vFiIho\nsBQU3oFAAPvuuy9ee+01nH/++Zg8eTJkeXi3uJuVdzfh7VmYhM3mREQ0cApK4FgshoULF+LVV1/F\niSeeiPb2dnR2dpb62AaVIQREMggASCLXet7pwK7f1pZzxDkXLCEiolIoKLyvueYaLFiwAD/96U9R\nVVWFxx57DJdcckmJD21w6a5m86Scq8873VS+uy2CpvZY9hZ2djPDiYioiNRCNvra176Gww8/HFVV\nVWhubsaUKVPw1a9+tdTHNqgMwwAMFUJToSvRrOfdlTckkQ5q9zZW5c0KnIiIiqmgyvvWW2/FwoUL\n0d7ejgsvvBCPP/44br755hIf2uD60rhqAIBIhKCrkawAzry3O3ezub1taY6RiIj2TAWF9yeffIJv\nf/vbWLhwIc455xzcc8892Lp1a6mPbVBdcsbB+N7Mg+DTqwFZR0cyo49fSieyxMqbiIgGUEHhbYfP\nG2+8gRkzZgAAkslk6Y6qDFQGfTj56AkIiBEAgMZok+d54b63WxI5VyGxA53ZTURExVRQeO+33344\n88wzEYlEcMghh2D+/Pmoqakp9bGVhZAwP+fOsDe8vROziJwBzcqbiIhKoaABa3PmzEF9fT0OOOAA\nAMDkyZNxxx13lPTAykW1MgotALZ37fY8nll557rXO+lrARSJfd5ERFRUBYV3PB7H66+/jnvvvReS\nJOGoo47C5MmTS31sZWGkbzSAXM3mmaPNvQm9qf1ztO31OvwVYyBaTy71YRIR0R6koGbzG2+8EeFw\nGBdeeCHOP/98NDc344Ybbij1sZWFmmAVhACi1uIkH3zaiBfe3gJkNJvruje869s2AQCUmhb2eRMR\nUVEVVHk3Nzfjrrvucn4/5ZRTMHv27JIdVDmpCKpAVIJm6ACAB55fCwA4cLLrukcS6EqGcdt7D+Oc\nyf+GQ8cchNZ4KwBApHzs8yYioqIqeHrUWCw9g1g0GkUi0f0a18NFZVAFhAxN1z2Pp9y/S8DqjlXY\nEdmFBz5+GADQEm8DAIhkiH3eRERUVAVV3hdccAHOOOMMHH744QCAdevW4aqrrirpgZWLiqAPEBI0\n4Q3vpK65fhMQGQndaoe3prLyJiKioioovM877zxMnToV69atgyRJuPHGG/HYY4+V+tjKgll5S9AN\n74xqKS0d3lLGgDUhhFN5QzYY3kREVFQFhTcA7LXXXthrr72c31evXl2SAyo3duWti+6azYUnoBN6\n0pk+VZJ1DlgjIqKi6vOi3HtKNVkZVCGEbC5U4pIy3GEunAFtABDX4+mnFH2POVdERDQw+hzekiQV\n8zjKVkVQBSBBhze8tYzKO2GkB/DFtHR4S7LOAWtERFRU3Tabn3TSSTlDWgiBtra2kh1UOamw+ryF\n8PZdp3QNAdd2yTzhzT5vIiIqtm7D+4knnhio4yhbiixDEjIMpKC5J2KR3TOsGXkrb7DPm4iIiqzb\n8J4wYcJAHUdZkyUJAgZSWrqpXJJdt4pJQMod3qmoazsDBpjeRERUPH3u8y5EfX09Tj31VDz++ONZ\nzy1fvhznnXceLrjgAjzwwAOlPIx+kyADEEhprn5v1R3eAkmRDu+2RIfn9ULSQEREVCwlC+9oNIpb\nb70VU6ZMyfn8nDlzcP/99+PJJ5/EO++8g88++6xUh9JviiRDSAaSrvD2VN4QSBnp9c2d8DbM0ysk\n721mRERE/VGy8Pb7/Zg7dy7q6uqynmtoaEBNTQ322msvyLKMk046CStWrCjVofSbLCkABOJJVwgr\n3klaUsIV3vF28wctCAAQYOVNRETFU7LwVlUVwWAw53NNTU0YPXq08/vo0aPR1NSUc9tyoMgyJFmg\nI5JuGpdUb+Wtwd1sboV3yhwf+2w3AAAgAElEQVSPzsqbiIiKqeAZ1gbbqFEVUFWlqPusra0uaDtV\nMU+TUFzXOlblLTQVkj/puQu8PWk1m1vhDVkv+L2GouH82QYKz2H/8RwWB89j/w3EORyU8K6rq0Nz\nc7Pz++7du3M2r7u1tUW7fb63amur0dTUVdC2spABCdi2M31vu2SHt+5zqvB9qvZGQ3gHuhJh87lU\nABIAQ9IKfq+hpjfnkXLjOew/nsPi4Hnsv2Kfw3wXAiUdbZ7PxIkTEQ6HsW3bNmiahqVLl2Lq1KmD\ncSgFUWTzNHVE0/3akDXz/m3NvP6pwlhMm+AdnCecZnP2eRMRUfGUrPJeu3Ytfve732H79u1QVRWL\nFy/GjBkzMHHiRJx22mm4+eabce211wIAzjzzTOy3336lOpR+UxUF0IHOqGvaU1UDdBWQzHu4fQhA\nldOn0y/7ENet5nb2eRMRURGVLLwPP/zwbpcNPe644zBv3rxSvX1RqbIV3rH0oDQoGoSuOn3fighA\nkdN98j7Fh6iumE0bDG8iIiqiQWk2H2pUK5S7oq7R5opZeUtOePuhSunwViU1fZ+3zGZzIiIqHoZ3\nAXyKGcrhuN3nLZzK2x6spghvs7kqqxCaFeYyK28iIioehncB/NatYl0xK7xlA5IkzD5v2A/5PM3m\nqqxA6NbvisaVxYiIqGgY3gXwWfeX68KqoJ3bxNzh7Tebyi2qpMIwrN9lnUuTEBFR0TC8C1Dh91k/\nmRHszGvuCm9J+Jy+cQBmFW5V3pJVeXdGk7j/2dXY1hgekOMmIqLhieFdAL/PCm/JmkdNsSpwwzXj\nm65mNJur6cpcMdf0/ufyrVi1sRn3Pbt6AI6aiIiGK4Z3ARTJOk2SgCSlK293szkMNavZ3A53STYr\nb3s98GSKA9iIiKjvGN4FUOxbwCSByqAPvoDVg+2qvCXd22yuuprNoegw2OlNRERFMmQWJhlMslV5\nS5JAZciHsGrAACB0BYmNR0EZ2QRZqvbcKqZIKgAZQpedypuIiKgYWHkXIN1sbuDEr4yHL2D1fRsq\njLbxSG35CoRhB7b9GsXZxu7zdkjSwBw4ERENSwzvAshWEP/7qZNx5tcmweczk9i5jxuArouMZnPV\n2UbKvM+bVTgREfUDw7sAduU9fmwIkiRB8dmVtyu8hchoNndV3jL7vImIqHgY3gWQrSVBDWGGtqxa\no8Vdo811XXjmNreb0IWuAIoGwzDSO2SzORER9QPDuwB2Fa1b4S1Z93m7m80NQzgD2wCkg1xXIUlA\nyuDiJEREVBwcbV4AO5S/6NyGz9o3A0rKfMJwVd6GAclVUctOs7n537iWXguceuetj3dgQm0lDti7\nZrAPhYioLDC8C2D3eS/e+joAQLZWGfMMWDPSk7CYr/Fuc+fqe3AELhqQ4x1OYgkNjy7cAAD4yy9m\nDPLREBGVBzabF0B29WUDgJCyJ2nRDYHHXql3frfDW1LNKj2hJyBghntnJIkHnlsDg6POe6TpRs8b\nERHtYRjeBVAk72kSMMy7vQxvn/fGhnbXa8zntN2T0tsgXZl/UN+Enc2REh0xERENZwzvAmSGNwAr\nuNN93PGk7unztkebG51jobWMN3+Gd05znfeP9YhniIgoG8O7ALKsZD+oe4cLRGIpz+8KXK+xKnQD\n3hHnDO+esWeBiCgbw7sAco7KWxgZ/eAwB1c5r5HdK45Z94lL3srbYHj3iOdoePvX+t247I6l2N0a\nHexDIRpSGN4FyN9s7hV2Vd+K69TaQS/YbN5rDO/h7c8vfQLdEFi2eudgHwrRkMLwLkDmaHMATjXt\n5g7j5vZk+glhbtssbYLkT1cYKY6k7hFH5A9v/PMS9Q3DuwC5Ku9RVaFuX7P4vW3pX6yg362uQ/Co\nt5yHUymGd08Y3kRE2RjeBcjV5z1+VBWqQj4AQCjQw1w3OZrYAVbehWCzORFRNoZ3AZQczeaqrDrL\nfFZX+Lp9vcjRxA4AyZSe83FKY3jvGbhWD1HvMLwLIOf4ZnEv/1kdyg7vQyeNxrUXHoWvHTouu/KW\nNQACb3e8jHe2vwcAWPDOFsxd8ElRj3s4YHYTEWVjeBcgksq+jcW9/Gd1hT/r+ZHVfhy272jzOeE9\nzZI/DskfxxfJT/HEp88CAJ5ftgUr1u0q8pEPnLWbW7BibfGPn5U3EVE2hncBJo3YBwDw5VGTncfM\nZnPz5xGV6fAWmlmFj6kYCQCQ5exmcykQg+RPrzJmrxMOwGmKz2f+ss34+LPmPnyK0rrrHx9j7kvF\nbznggDUiomwM7wJU+6vwwIw7cOa+pzqPqa5Z1/yq7Axei6+ZisTGo3DUhP0BABUBNavZXArEIAVi\nzu+NkXQYdxdWndEkXnznc9z7zOr+faAS6unio7fKObxffGeLs+IZEdFAYnj3guIKbFVWPfNu71NX\nBQCoCYzAdbPOwKTx1QCAiqAv655wSUlB8qfD+4GP/+KsEa7p3YR3JJn3uXKR1Io7gr6cm83nL9uC\ntz7eMdiHMaSV8bUZUVljePeCu59blVQ4y2ZIwMGTRgEAamtCzs8AUBlSs/q8IQlP5d2aaIW692YA\ngN7N7WPhaCrvc+Wi2CPoyzm8iYgGC8O7F7Iqbye7JZxxwpfwzan74rJvHOp5TWXQlzUPOmTDCe+v\n7zPd3F9tAyBrWZV3Uk9i4ZZX0Z7oQGe0/CvvRLHD23U6/ufvH6KxPZZ/40HCC4y+4y1iRH3D8O4F\n9/3e7j5vSQJURcbZ0/ZH7UjvzGuVOZrNIRmQ/HEoIoBzDzwLB4a+AknVIPnj0DIq73/Uv4CXtryC\nFzctQke4/MM7WeRZ49x93vUN7Zj32sai7r8YONlO37HZnKhvGN69oHbT551PZUjN7vOWDEi+JFQj\nCAAwdOt5SUBzVXHtiQ6s2LnS+b0jo897xY6VWLBpUS8/RWkVu/IWGVVtOS7mknnBRURUagzvXvBU\n3pKCQtK7MuiDENnN5lBSkI0AAEDT0o+7+7zvWzU3/RJJdgashQLm/h7f8DQWbX0dulE+M7UVu887\nM6zLsYk6VeRBekREPWF490L2aHMzSLrrtzNvFcucpCUBSQIk3QzvlDUOTZIM6Fafd2ckieZYi/Oa\nqBZzKu+qjBndolr59AMnSthsnuv3cqAxvPuNfd9EvcPw7gXPaHO5h8VILLIsZd/n7TMnaJE0c3IX\n3S5WJQOaYQbBtX98C7rQceDIAwAAsVQMHZEEAMCvKp77qbuS4d5/mCJyH0vxR5tn/l4e4a27Dox9\n3n0nCup8IqJMDO9eUFyBrcqq606xHsoG4X1e8pshbM/GJgzreUkgkdTxxqrt0GWzyq5UK+BX/Ihq\nMcQTZjDqhkBcT8/QFklF+vyZisHdtF30Pu/MyrtMwlvT0sfR3b35ROXglZUNWL+1bbAPg4qI4d0L\n7nW9PQPWemzyywhvnxnMwqq8Dd16Xjbw7Fub8bfFn0JSzI5wvxxEhRpCTIs5FZ4hBLqS6cAO55h7\nfSC5w9tdeacMDXd/+Ces2LEy18sKkt1s3uddFZW72uaANSpniZSOp17biN8/uWqwD4WKiOHdC1kD\n1iw9ZXfdqFDOx42kWXk74S0Z2N5kNoFLqtkRHrDCO6rFnYFRhiEQTrnDe5Arb1d4ufu8t3Y24LP2\nLXh8w9N933eZjjZ3BzYHrFE5K5fWKiouhncvSK5RNYprkpae3Pz94zC+60QkPj3G83gqYTbD233e\nkiTgXApY06X65QBCaghxLY6UZm5oCOFpKh/sZnMtT+WtGVquzXsl84unXAasuQepsc+7H6w/Z+bY\nBiLqHsO7j3yyAvf0qN0J+lVMCnwZRkcthKv/OxaVYRgCuqvytkmqGXw+KYAKXxACAklh9pUbRmaz\n+WBX3q4+by0d3rmWUu2tzLDOvO97sLgvWDjavP/K5aJsOOK5HZ4Y3n2UOT1qTxTFOtVGeluR8iMS\nT0HX0n3e6RdYlbcUQIVaYT1vPmYIIJxKjzAPJ7NDMvMfbCKl4911u5zqvZjczebJZPrnLtcxLnrv\niz7tO/N7Rx/gLyJNN7BkZQOice+88u7AZp93/7Fpt3TKpauJiovh3UfmwiSmQu5R9dnh7VqkROgq\nWjsTcPLUGm0OpPu8VRFASA1ab2pW45l93pnN5l/s7sIPf7cUb3603XnsuTc346EFn2D+si0FfT63\ndVta8doH2/I+7xlt7ro4CLtuYVv4Xu/fF8jRbD7AX0TPv7UZT762EX9fUu953N1Uzmbz/mN4l065\ntFZRcTG8+6jQ+7xtimIlvHuFMV3BLY+uRDhid3obTsVsh7cCPypUc8CbZFXjhiE8TdI7Irs8s6wt\nX7sLAPDU6585jzU0dgEANu3o7NVxA8Cd8z7C35fU521+y9fn7b7/3JD7tiJa1gxrA/w9tHFbBwCg\ntTPheZwD1orD/nOyabd0mN3DE8O7j1RZ6dWiCnblLQz7vxKc028FuiS5m83NKlsRfgTUgPVYesBa\nXDPv8z669itoT3Rgbct656V2FSP7EqhvMwPcp5qj4/vTbJ6vOvKMNk+6wtvVIiAUb/gV/J5Z93kP\nbFC2dZnHPbI64Hnc22wuEI2nsO7z1gE9tuGEAVM6bNUYnhjefaRIhU2PalNVO6itjQ1X5S7Sk7TY\n7CpbNgLpJnopfatYzArv0yedAgB4a9sK57VOv/A+a3DvqofwcdNa+K33T/ajSszXt5tvkhZ35d3X\n8M5s8hvoUcntYfO4R1T4PY+nXIP0NM3AnfM+xp1PfYT6hvYBPb6hzv6nM9AXZW5CCCz9cBt2tQ7u\nfAmlwlaN4Ynh3UeqrOLkoycAAA760qgCtvc2m8vCHd7Wn8E1YE3yJyAMCbLwwWc10UtyepKWmBaD\nLBQ89sIuHDhyf2xo24hdkd3m7uzAC5lBMn/Ty/D5zPdI9WPu8XwDX9yjzd39v+5BdYbct+VMM9/S\nEAKabmTNvFYq9mfOfD8tY5KWLTvN7oimMlxvvJw5zeaD2POweWcnHnulHr+a++7gHUQJsfIenhje\nfeSTFXzntC/jjh9NwWH7ju5xe6fytprNZeSqvN3hHYNIhqDrrv512Wo2N4CYFofQfdi8vRMnjDfv\nH/+0bZP5vN1vnjJHqTdGmyGpZgWZ7EezuZ5jGlAhhGeeb/e0oRHXKHhD6Wt4Zw9Yu/z3b2DO3z7o\n0/56w90FkDkoTcszYI0LbPRNb6vD9zc04sEX1xWlqozEzC6q4VqgsvIenhjefaTKKmRJwtiRuWdP\ny9o+Y7S5e7S63Q9uN5srqg7Jn4RIhKDpBnyKtYqY5K6844BuTtEaUioBAM+/vRHN7bF0FaOkB4nF\nVXOFMvfgqriW7hMvRGbl3RZvx/ee+ylWtb0PqGY4u0MtYbgCW+1bs3n2gDXzd7vSLaXWrvT88cmM\nFgv3eXT/LGekdyyhIZbo/2Q1A03TDWzd1TVg79fb6vCP89fivU92Y3cBTd2vf7itV5/FMATunPeR\n526NoYyV9/DE8O4j91SphVCt0eb2JC2q5FrW0x6wZjWLT5xo7lskQkhpRlazOWCGt6GZj8swt4+m\n4nhx+efpK20lHRqblXcANenp835iwzO4d9VD+KhpbUGfQc+oPjd1fI6ElsCyliUIHvkG4Is74W0I\nA5qhQYHVV9zHyjuzz3sgFwGx108Huq+8NU/l7Q3vK+5+C1fc/VaJjrB0HlrwCW55dOWA9eH3tTp0\n5k/Io7E9hsdfqcctj+afXz+ztaSxPYZ1W1rx10Wf9umYyg2ze3gqaXjffvvtuOCCC3DhhRdi9erV\nnudmzJiBiy66CLNnz8bs2bOxe/fuUh5K0Vz+le/hrP1metb2LoRdeUtWda1KKn71PWu61IxmcyVo\n9puKRAU03Ug3m9vN6rIBXegwUnZ4p5/fvKPTudIWcgp1FWMBAElEoY773FMl2iPUN1rN7T3pbrIH\nSTGgjGx0giypm8EXRJV1Avp2q1jSSHpaEIq95Gh3uqLp982cRU3zDFhzDTTsY7N5S6wVN7xzOza0\nbuzbDors/Q2NAFBQZVsMfa0OMy8oMxXy/0vmn2y4dX2w8h6eShbe//rXv7B161bMmzcPt912G267\n7basbebOnYvHHnsMjz32GMaNG1eqQymqI2sPxxn7fb3Xr3Oaza3wliDjgL1roMiS0w/ujDb3m1+Y\nduWdsjPErrytMBO6VZFb64VLio4dzRGzipEMQNYxOjAKFx30LfP5jACt9JnN7YVOr6plfAnYI95t\nysgmZxR2QrdmiDMqrffuW+W9UnseoWNegz20qbezRdW3fYZH1j2BVB/mWe+KuirvjLECqTxzm/e1\ngnz1izfRlmjH3DWPFbS9EAJPLKnHui2lvT2tKuTreaMi6Gu+9HSPfUE5LHX765DHPu/hqWThvWLF\nCpx66qkAgAMOOAAdHR0Ih8M9vGr4UuzR5s7tZVbftyJD2KPN7T5t1aq8k0GkdAML3ramFrUGrNnL\nhcIKbwjF83wkrjkBH1KDOGj0gZ7nbVU+c0BbOJk7vNvi7Xh03ZOQrIuJzCrHvtf8lJFnw4hXQK5u\ncyrUlNXfrYgghC47y6D2VhhmOMk1zX16/b2rHsL7uz/CxwV2Dbh1uirvzJDwDNJzN6FrfWz+tVpy\nNFHYRcb2pghe/WAb7pz3UZ/erzvukfWZF2yl4q4OOyNJrN3ckndbz/H11I3ShzJ6uGXdnlh5G0Lg\nd3//EP9c8flgH0rJ9G6asF5obm7GYYcd5vw+evRoNDU1oaqqynnspptuwvbt23HMMcfg2muvzeov\ndBs1qgKq2rum6p7U1lYXdX/dGdlsNT/azeaKitraavhUGYmUt887FJKBlFlZ+/wqWtpTwCjXJC5W\neAvdrIpGjrA+hxXOmiGchU1GVY/AXnXWrWzW/u3PXREIAl1AzIjmPBfLPnkbK3evQuAIGfH3T0f1\niJBnu+QXZrhVh6ogkgHIwSg0w0BtbTVi7eaAMkX2QST9gJrM+R7vbVuFUcEafHns/p7Ho/EUKoLp\nqk8Zux1GR61nm978/Xyh3v+9U64vPUOSPK/3B9LHJrv6XYMVfmc7dytBT+8dCJj/FDVDK+g4O+Lp\nC7Fi/3/c1pluUQm5Pk8pqT7FeZ9fPLQEja1R3Hftydhv75qsbSOx9EVVVXXQeV2u44y7rrnyfY6a\n1phnm5he+N9tKGgKpy+cC/k8w+Ezh6NJfNrQjk8b2nHJN78y4O8/IP9mSv4Olsz7ZK+88kpMmzYN\nNTU1uOKKK7B48WLMmjUr7+vb2orb91ZbW42mpoEbTdvVZX1BWAEqdKCpqQuyLOWYpMX6YhYyOrvi\nUCUFCddrneZva8BaW6s5ktsO/65Iup9Y0hR0tVnPWzO0NTZ2QpIkdMbMintXuAlf7GyCX/Z5+vI7\nwzFnv1IwjJaWCJpC5nsmkjpefOdTqOOARFQ4rQApI4mmpi7s6jAHOukpCdD8kIKRrPNtCAN3vvMQ\nAOCBGXc4j2/Y2oY7nlyF807e32yokAClphkpyfBML9ubv19LR5dn+22NYUgSMKG2yrPdZ9s7cO/T\nH+Pq849EY0u6RSIWT3le3+EKuIireb2tPepsF0+mq+jujvWtxmVY9Nkbzu/23ydTMqXj3U9247iD\n69DWnv73UKz/j1es24VdLVEctl/61sfWtuiA/DuJu85vo9XPvnFLC6p82Y2Dja576Ztawmiq9uf9\n99zSkm7ty/c5OjLOZVNzz68ZSlpb0/8f9/R5Bvp7sVQiroWEBvrzFPsc5rsQKFmzeV1dHZqb002d\njY2NqK1NV05nn302xowZA1VVMX36dNTX1+fazbDh3ELk6vMGkNHnbQ1YU+1FjmVougG/Yo3Ylg0E\nfIrTbG73ecPwNpvHk5qzTYUagk/2eZ5/7q3NeGfNTqfPOqEncd1bv8b1Cx/CZ9s7nGN2z58uBSOe\npuKuWNJpAZAMn3MsQtagG4bTbA5dgdB8kBQdCc3bdJ7Qc98+ttIaLLVw5RanA1JSNchVfR/5HE15\nJ0/59V/+hRsf/lfWds8s/QyRuIZnlm5yms1HVPg8zea6YeCL3el/nO4+b/e98O6R/d01Xc5bu8Dz\ne0TLfaG6YPnneHThBjz56saSNO3OXfAJFiz/HM0d6XM1UPO25+qXzddk7668e1qOtZAxEpnvM9xW\n4ervx+mIJPHBp03FOZgBMtz+hrmULLynTp2KxYsXAwDWrVuHuro6p8m8q6sLl156KZJJ88t85cqV\nOPDAA0t1KGUhff+vNe847D5vKV1NWuEtK1Z1bshIaQYCavo+74BPTt8CZjWb6xrM6t0K51hCd6rz\nkBqCIiuQhAzJev6fK7bi4X+uzxpwFg5twd3/SPehulcFU0Y1oiWRHhxlGMK5QJCFz6m8JUWDpgkk\nrNHmwlAgUubFR1vcezUaTaXf372wiv1FLqvmY0I3L07kEd5+UPMiQcsK5lw6k7mvhA1hYHc0/cVk\nT6aj6Qa6oklUhXwI+BVPiK1YuxtrXQPFtDyD19yz2el5phAzRPbjLbHcg9B2tZih/vmu0t7j3tKR\n/ruUMrzdrXG5Lm7yjST3hHcPo80LGayV+d7D7YvffQ76MjPh//z9Qzzw/JohNfVvrgmlhpuShfdX\nv/pVHHbYYbjwwgsxZ84c3HTTTXjuueewZMkSVFdXY/r06c5tZKNHj+62yXw4kK0Ba3a/tV15T6yt\nAmA1nctWVW6FNwwFKV0gqKbv8w74FSek7cldNN2AJBSn2Tye0JyAt5cTlaB41wuHQEJPoNJeKxyA\n0FTPVbq78lZrt+PvDQ8imdJR39BuTlpih7er8oaiIaUbSOr2iHgZ0Mzwrm/x3pIW1dKh61772/4y\ntS8OjC6zGVcOeQc8aprAw2sfw8+W3YRIKuoJccMQePHtLc5kOJ3J3IG3cMur+M27v3fudbfvCtB0\nga5oCtUVPvhUb3hv3tHh2UfKM2DNtba5PUJd0pHQcg9Ei2vZrQ/5Rv/b/w/phijpl1OLq0uglMud\nunMkZ3jnCdGwK7x7Or5CgjgrvHvY519eXo///r/lPe63XLg/X19Gntu3Cw6lqX/zXSwPJyXt877u\nuus8vx988MHOzxdffDEuvvjiUr59WXEqbzugrdHml5xxMPbfewdeiSsw7AFp9n+FDE0zYOj23Oe6\n2Wwup8MdsKoj4Qp1pG/NspcTlYXqHW2uaBAQ2K9mknO/t4hXpudgR+4Q+cvL6/Gv9Y046/9NgqRo\nELoC3YCn8tZ1A0nDWr5UV2AkzGOYt+kZHD/hSAStVdJirvDuSHRiZMAcnGR/v8jWoDsRr4DQFUhB\n7/GkdANrms1j//mymwEAPzjsIhwz7ij8a/1uzH97C0LHCEABOhPp4HdXa+/sMCfvWLV7DY6qPdwJ\n70RKRziWwoSxlYgndU94B/2utdxlAy0j/gUlPgJSMIIW3Q/AHHyXTBmAZCB45FuYv6kT3z3sW1nn\nM7P1AzAHreVi37FgGKLHirM/8lXeiZSOrmgSY2sKm1WwJ+4gyZWx+YI3Ek+fn54uYgoZaZ35Pj0F\n/turd1rbGVDk8p/nyhPehkAP89rkNZRaJIbSsfZV+f+fN0yMqQlaP3mbzasr/Pi3KftClc1wkkc2\nIpwKQ7Kq8ZRuIJGy/keUzD5vJ9ytyjulGea93q5wloLm1fLY0Bjzd6E4zeZAetWyoBLElV/5sfmg\nrHtmrAqnIhgd9C668q/1Zn/0xoYOc1CcrkI3BISRWXlbzea6DL3xSxBJM7Cjrv5cd3gv3rrUqdad\nudntixFdhYhXWp/JfZtQdoDtipjHZ37BC2cZ1Q5X5e2e6tReS317s9msbs+E12atJmZW3rInxGLW\nQLTbLjsBoTFtSFR/Dv8Bq+GbsAmrxItOU3hS0yH545D8CWzsyD0NbVw3gzKoBHDKPicCQM570qOp\nmPP31Q0BrciVhbs5tdm1drm7JeGOJ1bh539a4Zl5rj/0HirCfBcoUdd0sz1V3oWFd+ZtgIV98Wd2\nKQgh8PFnzWU3Ha773PYn1IbSLWdsNqeiGVUdwG2XneBUywq8k18okgJJ1RD48ofYEdllzaomIaUZ\nSKYEhCFDkg34fa6QFq5lPo2McA5EASFhTMhscpZyVN4AsGZjJ3738GdQjRAgG051J4RAOBVBlTWR\nS6akpluVt2p++en2RDEaNF044W1oMiBk6O3mYEU7oAHvILKPm9bipS3mGIms6V11FUas0hz17k+/\nRtMMyJL3f2FP8Lmmh+1IdDrv51621J4AJ2m9zl533V6UpLrSD58qw3AtwBK3ngv61Zwz7dmfMakZ\nkHxmOLfEWz2f3WZX3idNnIqJVXtnfwaYs9XNee9ObAq+CkBYK6sV98vJHUSeytsVjvZ88u5m6/5w\nh0GuUMn3GfU83RQ5t+1L5a27jyv//jOPb/naXbj3mdX466INPb7nQHJ/hP5c8w2lanYoXWj0FcN7\nAO01phIHpk6D3joOU+r+n+e5zBHGftkHVTFHmyeSulllywZURXaazYWr2VwYCiRfylkARA5GIeuh\n9LzoIqMyt5qk7XlzzIsD3ak8E3oSmqGhUq1AcrN5n6Q9hzoAJDTdDEddhaYJT5/3X15ej664GZS6\nZq+mZr7WDnXAW3kDwEeNa8xNnT5vb+UNAFIo3XSe0DSnYjyq1jzGlGEHp56ezAaAgMD7uz92nks/\nYU9ba430z2hTHFFhhjeQDji7sgr6FShq9rdh0khCNwxsbwxD8iec998VzZ4C2J7oJqQGnb9VKiPk\nP2hcjY5kJ8LyLsjVbVafd3Er77hrBTXPimnWZ3avsFasL3F3tZ85h735Prk/o2dq2j40my98dytW\nb2rJu02+VfIyZVbe9iDGTdtLv2hOb3gGBvbjNoVi/z9XSkPpQqOvGN4D7IpZU3H1cT/AtMO+1O12\nqqzCp5qVdyJlhndNtYq6kaH0wDO72Vw3zIFhAEJfXQqoCUj+BJRUFYQQ+OPzaxCNCUiygNPsbM8X\n7p5iVU734dn93RVqJTsCVIYAACAASURBVPTmCdA7R8OADsBuEk5Bks3Qjic1T5/3Z9s6sGqTGVS6\nZlXyVmVu94UDQDSjv7cl3obGaHO6/9OpvBWIlNns7p5mNZyMQkDgyNrD8c39Z5rnwqpaw7GU83q9\nrQ5CAG82LEdjW8QTRPY99kKyBt9l3F49wmo2B9Jf1vGEBglAwK84I+LdknoKTy/dhKde/wzwpZug\nd4R3ZW1rV95BNQjVuqVPM7zhvXLXh87PytjtVp934V9OWzq+wJ/XPo5wMpJ3tHE8zxzg9mduaDSv\n8tQJG/HytpcKfu/u9NRsnm+kuztce2w2z9hvNK7h6Tc24Z6nP05v002fd+b+3ecvc8rcZmtA11in\ni6w86D20cBS8nyE09Vyxu5XKEcN7gPl9Cg6eNKrb2eQAwCer8Cmy1WyuQ5FVCDkJQ04BkvWl4fR5\n604VDgDKSPP+eilZiVhCw/ufNmXdCy75zdCwQ1FYfeZ25b0zYgbNCL81QYDzeiu8EXFeH4lrnsob\nSFfYuqZ4Xp9wVd6fN5mVynXH/Bhn7GtOpdueaE9/Qcqu+9l17/EDQJc1rWuVr8IJPrvyjsRS6dHq\nsSroLXthV2wXfvXss3jh7S3pE23tLwnzizfzy7raVXl/vH0zVu/YjFhSRzCgQJYkyEqu8E5imTWo\nya68ge7DO6QE0pV3RrN5S7wN1b4qyEKFXNmRNWBtV2Q33tuZf33zf9TPx6rG1bjx5b/i9sdzbxfP\n009rn4+GRnNMgG/CJqxu/xDhWApPLKnvVxO6O0dyZUrmMqw276IwvWs2D8ey++u7azbPvIBwH1Pm\nc01Wd8OoEYFuj2mgGT3cklfwfoZQNbsn9HkP2Axr1DuqYt5fHEtqSGoGAlAQTnXhXflvgGwu4iKE\nq9lcl5wFFeRqs0lQSoUQtkfmWkHv+9IGpD4/DFLAXrnMHDls6GZzvH070spdq8ztIxMAtDmVM2Qd\nMFTEpU4oMEeoR42Uq/K2mrqtCwwtZVW2OZrN12zdBXUsMMJf5dzSFtPi6S8J2T52NT0JnSss7TnZ\nK32V8Cne4AvHNE+fubbjAKhjd0Ie0YpVG9OTB9n3w8eMMO54/37ElNEA9nKe19ROrK94BsrY/fBk\nwyLz/RJnOyPOJVflbUSrIFeEkdCTqAgoiCU0p88bAHZEssM77qq884V3OBnBmNAoSMlKdIR2Q0fK\nEzi3vncnAKAmMAL71UxCwJ7Uxz4uawBdomIbNm34ctYxAN5mczc7HKMJLf33ADD3pbVYs6kNmiHw\nvZkH5XxtT/L1ecuSBEOIvCuCefq8ezlgzT1ffa73BrxVW+b+3bPmuS/0hBDOQL5yC7nM0eZ9NZQC\nkc3mNGh8soqAT0VXxPyySfc3Cydw7C/7ZMqAcC2bKVeYVZIwZIStLys7PNW6bZBrmiFb4W3Ezfu8\nDatvWlENGMLA6uZPEMIIvPCKNWGIYc+/bo149pnNqCJe4am81boGyKN2Oc3Qeko2mxFzhLd7xLs7\nvJ2Kxj2TnD0Lnavytu9Dr/RVOLPI2f3F4VjK2b/QVAgtPdGNh7WNhhS2djag0f+x5+k2YxeSUgT+\n/dMLm8STOkKBdDcBACTWHwe9dbzzGYP281blXaFU5q68dbvPO2Teiw/vrWIpQ0Ncj6PaV4VqqRaS\nBBiBzpxNyvd/NBd3rLwPQgi8/uE2ayY2gdZ4m3ksqubchZApka/Z3AooTRee127aaf5/YfSjedId\nJO4+b7v1J6nluaDoplk7U+aXeFeOkfLdNptrmeGt53zOfftavhYDIQSefXOTZxbDgeAZbd6Ppu+B\nWqSmGPaE+7wZ3mXi9Emn4OBRBzr3OvtkFUG/4vzDE3L6S8eerGREyAy8aELzNM/KlWZ4G7qcbtbU\nXaOiJQEpEDOraWsCFfteclk2oBk6UkYKWjQEZ35Su9ncqnxl64vciFeiPZxIr3AGwDfhMwgrZFMp\nCaOqA1AlMzyT1rSpumFO8iKEOUNb0BXedsUl5HSfd2azPwBENKvyViucCxk7+CKxFKCmK3e4lk1N\nnwcjPSFOPkp2c3IslUDQnx5dDwBC8zvvsaO1AxV2ePviEJoPtYFx6Eh2eia+AbwD1pZ+YIb79tb0\ngCd7lrsqfyWqYN72JwKdeQcP7Yo2YmP7Jjz+Sj2WvN+AjmSXZzIcuTJ7MFVcS2Bly/KsVeeAdEBp\nugHZdZ99NGVdlAT7vmSonmcglWKHd54Q9FbehQ9YE0KgM9q7ZvPsyts1sM89IY/r4iffRcfnu7rw\nzxVbcftj+bs4SsGd17kGBvbEPb/AUFGs1oZyxvAuE//fAWfgJ0dfZt0iBqiy2WxuS0npL2A7qGsq\nrfCOa5B82TN1CUN2+vjcfeKSrEMKRK0mc8nZFgBk1XACUNcl1768fd72hCkiXoH2cDIdrjCb0u3K\nW+gKVEXGiKDZPG/fLhWOpszPofmh6cKpvONaHAnNACCQ8rUDwgxGu9nefTucHUqVvgrzVjtIzoC4\ncCzlDG4TKX+Oyl3Af9BKz2fPJQVr/vdPj4HeYYYnanYh4LcvqlyD6qxz8MTrG8zKXElBCkZhRKsw\n2mfeKpdZfcdc4b16o1khN7anJ5SxZ56r9lVBMsygFJLebRW0rvlT5+ftHebAwUlV5gBJKZQ9Teyz\nG1/EB13L4Nvn06zn7PBKaYZ5+6HFvmjpzz3NIk+zuT1oMpmnP7uvfd66Yc6alykz4LuvvHM3m7tb\nLqJGFz5sXJ31Pok8XROl1t8Ba3Z4r9ncUvKpeYuluwuw7vx10QZnEp5yx/AuM4p137JPVhH0eatl\nN2HIGGmFdyyhQW8dl7Uvs/K2vmxEOoilQBSSqkEkXTNluSpbu8/VHinuft4OTykQMydesSZnCfpV\nnDzCmkFMNiAk3ZqaVIJPlVFTYb5XJGmGVWs4BikQgxGvQDJleJrNkykdck0z9EA7KhP7mHO4Z1T+\nABC1Ku8qfyUkSYJPVhFJxLHgnS1WeNvN5n4AMoQhpcPfl4AywgxLvWW8s09ZT48U3m+vaucWPpEM\nQsTMufn9B6xG20irerIH1bmqe8jmjGxyVbvZzN01GmN85t9nc8fnnr+RHd6bGiLOHPbuPm+7X7/K\nX+lZtz39hZT9ZRxOpPvZd0fN/v0vjzBnN5QrO7NGnO+KmhPb5Ap2O7xSuuFtclfSa8c/99YmvPbB\ntqzX9iTfaPN05e0Nu2ff3ISXln+OsNTsdH/0ps9b13NX3lpGuOVbqx0AYnmazd2tBLvHvYSH1z6O\n19au97z2b1/8H/wHZy+GUyhDCM/FQ8Gv62cVav89GhrD+M2j72PTjg7c98xqRON9v3ArplxjI/py\nwZJM6Xjzox34y8vre964DDC8y4w9baoqKZ7KO4sho6bKbPKOJjSkPj8MX459A0YsPamK0CWn8naW\nEQUgWc3u9qxn9v4As9nZvlXJXXlnjVZXXCPMAVQEVewd2MfaRoMOzemHVhUZFX4zFCNWsOzobIYk\nCYh4BZKajpCSEd6VZr9gZXQ/81hzNJt3psyFEsYEzYrYp/iwszWM55dtMf/B+lyVN2BeaNjH75rn\nXa4I43jpAnNbpB8/cOJIp5lbaH7rIsDU5WswH5fTg+Ls1gF17834rKUBcrV5cWCER2Kczzw3G9q8\nM63t6mqF0GX88dkNCPnM/bv7vO1b9qp9Va7WAyNdfarZYbR5d5vzc7u1GEyNOhpGIgg51JX1ZaZI\n9oWZAXWfDVDqvjB/l1zN5prhad2RrM8djafw0vKt+PuSeuyM7Ma8T5/POV97Lkae+7ztqYTXbmnF\nR5+lBxf+c8VWvLBuGT6vfhnq3uY8+T32eXtmFzN6rLx1XXQ72txTeWu5K2+7p+mJ1z51LpQMYaAj\n1QZlRO5FZwpx3zOr8V93vdVta4emG57lMM337t993plTwP7+yVX46LNmvPnR9l7vq9iefXMTfnTn\nm9jZ4p06ubtBh24bt7Wjrcv8/zVfyLd0xEs6HXFfMbzLjF15GxCe8E6PJbcYCmoqrfCOpwChoEau\n9TRf64bkDFjzfPFat4l5mrqtn9uqV6fvv3Y1J2eFp6x7Xl8RUOFTFXMOckWH4QlvCRX/P3vfGW9H\nVa/9TN/19H5OzknvIR0SEjpEulIFiShYLyI2BEQR9PpD5aJX5d5XQbHAtYAIypULWABpIXRIg5De\nc0pO3XXKej+sMmv2npOQkJAE5vlAOHvKXrNm9jzr356/wfTMWax0xyDt5EUKSdiOhxjTYM+5eRSY\nJjgdgxEYq+w2H3D7YGkmKkxqERuqESB33eR9z/k51OD4GZztI2F5FehIjxDu/mRMxylzRvgxascA\nsf34rgGLndIG8VhnODZGNZaFPvlpkZvgZSphKnG0pVqwrm+DiGl7xENPoQskT5vT8LwDuXXqoBTz\nlhdQPO6rsAXKmMqRmBc/EwCwTYqZ9xeY7CuJg2QroJhF9GT7AxKnQo42MQijeQN0Rt5xU4fteMg5\nebylPwa10idSYXnnfCL58Su348mtS/D0tucwHAghIg8j2DDD30f+/Cf3Bd3PWgO18DU2Fsfx8Pra\nnmFL1uRzOR4JlXYtzXrfXZ33cAlroXFuhYhEtqCG/b4RAReWkRvHlOJbv34Bn//RUwGyGS488Xah\nlogfcC/Du9Uudnd4aMlGAMCK9cFFUVAlL/yaO3uz+O7/vIyb734RQDjJ9/Tn8dWfPov/vPe1sm0H\nGxF5H2Lgcp8e8QJu85OSl+CyKR8RWeeEqEjFDWiqItxXhq4G4rfE8RPWvLxvkYsabznWy+uwY9vx\n5zUPBT4D4Mufqi4AAqjB2vKERevS4eo0EU3xyds0NCTMIHl35ujLl1reXiDmXXRcEVsnpCRhTvXd\nxUNuH+pitejpz+P7v30Zjh20qDXTptYwczcTT2rqwv51drbD3dXC+qYbgOohldBx2xePRW1lDBk7\nA0MxAaIGLG9DYeSt+AI1gfkC/GQ3x4DjemhPt8EhDr551xMAgK5cD4jiwcumEDM1DGR4jbwjXLfC\n8jZTUtzft7x5XH989RjUaC1iO8cga4X6qwfXw8vSmv2/L1+BL972NJ5bSePv/YX+wHGinaylwXE9\nPLVlCfr1jZClCbjl3ZcpAIoL64h/iYXGa10rUIqubA8Gi0P4zSNv4KofP4Wt3ZlhNbdLX7YD2SJW\n71oLrWETVOba91gI47W1PfjRH1/Dd38d7o4W51IdvLTzFQxk/UUst4qD3+1hlf009DYa/y+zvAsS\nebthbnP/M0X1hFUnt9flHqF9xe6M561d9HmRPQHDLYzeLrRS5aJDEKW6GYFF2zBW86ad9J70MC3/\nsORH3tt+1cbesm0HGxF5H2KQyduSyLs53YA5jTNgKiwm66mIWzoqkqZI7DE0FfUVPkm7riLI29ky\nDsUNkwFIwiEy2Uj//0bvWwDoAsHfTv/fHLNMxHJlyzwRM6DrKrW8VQdEcaEpPB6uIcXc5isHX8fD\n6/+B3iJzKRcSKNgutnXmoCoqNnX3omh7Qq5UfAdhMWtOiEYBLhxk+k1c87MleHNzHwaGXMhtTxW9\n6LvMAboA4W5zVodOHJal7hIYGvMU6P6POGNnYaksN0Amb7Yw8VAU4QNSQt6KXqTKdVDgekR0U4Pq\nghCC7Sx5jeTSKBRdZLJ+jTx/6QvL20j6fd+lmDe3vFNGCiopDy2IVquOKch7xY4NAIA7HlwJQgj6\nCiWlSxqXf6WWt0tCrErNRUXSxMBQEUosCzVGX3KqomJd/4aApekRD//x4m345Yrf4cnXaDLQ+m0D\nw8a8S8l73dYB/PjV22GOXClkfVHiiXpdcq/L4C9xo/0N3Lf+TxhM+fFM/rIujXFvJstgtKwXf3O8\n2rUcAwU/L6DUba5YGcSP/Jv/5ap/H4ekKgO5MmRf8HZ6cpcuSDj25DbfuGMA/3X/soDrfTjyPpRy\nuEuH+HZi3p0lLU7DKjhMQyv77FBBRN6HGHi3MZd4Abd5OkGJQybvmKkFpBh1XUVrbYX4m3gaeofY\nKp9ocDtHiAYn/Bz+viGPQpjbHIA1eSn9n5KYNz1GBzQXRHVE85WYqSNp+eP86/q/+brmtonnV+7E\nt3/zIlxbw2CBveTYGF23xDvAPuelal2d0o+LaDSBTKUdxVy1ECBcriInn58vWlzPg6XSfXVDJu8M\nYiodu6gVB+ApLC9AsUXSXqAcDzSWbjHCdlxPiKcomgvXI9gyRInMy6XYi1AR96efuXeHbE7eKSGB\nC9XzrT5meadNSu6EoKScboiGDYgKkqXPhpyYtq2/F45EzqpnsnI6D6ahwnY9IfIiQzc8NFbHA+1n\nTxt5MmY3TAfgl8ABtAFNxsliTd86keCn6wrk05JhyAYA1m4rr4tWNV8NcHcQOvkshGEnt4ltnHxl\nK01uHSvvs7p3LX6+7C68YD8ItXon9Kb1cFwPHgsDFG0XetPGkkF6tIwS/n0EEBDuKRsv8fZIzm8n\nbC1n4e+N5X3lfzyOl1d34bkV5Tr85eMoP9ef1/zfbtX+9hV7Gnep5e0GLO/wY3mf8mSMl5mW73co\nl5lF5H2IgVvepJS848wFzcmbqIgZQfI2NEVYhAAATxUPKIXix39RalmHrDBD3OoyFATd5lTpTYei\nuVAU1gwFQNzUUBEP9oC2XZ6lreI1Fssjju5b1oxcbVt+80iWM3P9J9W0fz2uCkUliM/5BxQzD6J4\nAVc37bxGaDy9pDOb6xKYjFw1g373uv4NsD0HMS3BxufPnY0C8k4BLmxhvYfNkaHQc7oegcoFDVUX\nRdtFX44nDkpa2GyBMsAWXZt6ekA8BU7RD4koiivkTLmLO2kkqTEqhwYA5NwMVI/OPfdCFIlv+e0c\npLHCZHYU8svno4KwzHvdEfK8A5LL18vTcyUTCpK8xpuPQU+IOZRlcLn17xFPJPFpqor/emCZf97d\nSHgOZotlOR+q7ore67sDf4nzygoS8/MBuFXtegSKlYU17Um83hOMsfMXf3+BHpdVemGNewVG+5so\n2g4efHo9rvrxU1i5cRfUip7AsUrA8vYTqoazvPNOAf/+3K349crf7/aa9qTbrVZ24U3mPQPefsxb\nnvdk3F+YD7eYKP246Nr4+6YncNeqe3Y7vr2F63n42h1LcO9j4W11AZQ6YgJW9HCaCNt76LuxtoL+\n/sLc6283Uc0jBKs29r6jxi97i4i8DzHwhDW3JObNLW9D5a5XD5apo7bSJ0VdV8vIuxQyAQXIhoQ8\nCnsgd/m7EjFK3rL1yWO0MVNDMhaU7LSJLb6DJxEpngHD8jCiIeWXAknhQSK3PWX/5vNAR2MaJ8xs\nDYxXYdnqPMnMMjV/PjSnrDOb43rCbc47hf34lTsAUPUzOugY7E1UCtT2CljeQ12w3mBN4FwyTGbN\nP7NsOx54YpMYe9HxROa9fJ9UaIDioj9TxKadg+jNDQKOibVbB0RCG1RPlCxx8t64Nc+6z0neBcVD\nkRRgEDZ+fq3En1SejY5CEqZdA5PF8hXNgc403Tlx0fmk280YEd4WbnnHtLjwLshKev1539LnBNfd\nnwsmj9VtwD82/QuATzAXnzyOnstxxaLW3jyOnch5W+QtkvGYkp6iEtF5T7a89Za1UONZPLL1kcDx\nolQupMd63inikefpPX3hra1Q48GMZyhyzFsi72Es78c2P4nOXDde3PkqVvS8OSxp7l6m1IM14SX8\nz9q7sXGAVkS83WzzjTv9+/R2Er5K8wGyUmfEMG/NviKTc9DVl8fGnYPIOTlsGiwvS1R3Y3kPN/4u\nFs/mW12XAEYe8SMfwWObnsTTr2/H7//xVuixpXjxjU6ahf9WePjmQCAi70MMgZh3wG1OicVQfOvN\nKnWba6qwfADfsm6o9gl+WPIOURIjw7jNOUbUVfqHqwomtlcFysc42cRMXciJcnDxEz6GuKVjYmsD\nHGLjcxeNRnUFHWdBTiJ2Zbc3V3BTkU4YaKyOB65HJOUxy7syYYpriM96TKjQnbuQkoHjEphsMdLX\n9Dh6cr2iZGt2zZHivM6OUXD7a2ATGy/upPrvXBa11G0OACZbbK3fPuhnzGsOirYrVMrkudVVHVA9\nPL9qJ2761QtQDBq394h0P1TXLxdi5L1lexH5okv34d4JVmGQGeT3UQFxNbjwJ7UvT4nZKRiIW5oY\nLzQbpk7H1S/FefkCSTdcn7wNej5L8cm74PrWZU/Od3trjLxL1dOMjlV4YM1DcDwXhAAT26tw1GRa\nG19winCJC7evDs72MaKigTeMCUPfUAGPv7JVkJDcjc4cuRJqRXfA8pZDQNLFis5hgYQzBqphH1zA\naH3t0gLDldzmEnlb4eQtJ/r9v9fuxLLulaH7DVceV3RtaHV+WODxzc8AKPdqEELws9d/jf9dG1yo\n9PTnWRc8EiDm4cgvb7t4+LmNWL2ZlmxmbT+GvCvfF3rMvoCX5xUdF//96p34/gs/KRM7KvXWBGr3\nh1ns8NACfw4c1xNVDH9a81f86onnsKnTv++7s8K3sERBfr/fDUTkfYhBVcOzzbmVoTGZUUV1ETM0\n1ErkPba1UsiE0pPQ40c2paXPJHeYbJmXan5Lx7O9yzaPb6kVC4yBrI3KlIXjjmgX23lTkpipiZcc\nh6uxlxnLJm+pTeDoFkqSd6+6V4xHdpvLMWthgbsa0gkDDdWJwPWUlsNVJM3A9ei1NN5cnaIuccfz\nhCeBqDbuXf0AvcaqMWhPjQheuEv3W9e3EaYSA8nR+f33y+eXzRG3vAH4Cxtmeedt+sKvTib8/TUD\niurhjU19gOJC0VwQx4TrefA8iJg4J29uUaowaRmT7DYXjVmkBZur0wx5Bp4QZ+cNxC1dkLeiOeLe\n9hcly7tAnzdV84TbnBOXqcZD3ea9WXo8IUy6Vy+KEkYK/9nryXK3ugKTkXPe4wsxQ1wDNCcQ81ZT\nvbjnzT8LBb8f3PMq7n70TSxdxWK3hpSAVbMT1sQXBQm6HvGrGmRIuQWDdjl5F11byMNyD4ilxv3K\nDtUT3gWZvLmGQSl6MoMwSAIN8ToAwxPgcKpyj21+MqDBz/NKZPL9xV9XYe2urVjWvRKPbHxMfH7/\nk2tx9zPPIT7zceitawLqdjIxquke4bnY3p3BH59Yi+/9lraslWV4d2a7Qse4L8ixDP9C0cP6Aerp\nKG3y44sJ2VjbtwHLi0/AGE1DII7r4anXt+FP/1obOIaHRGQJYPkdEZu6BErCf/YzuxGl6WFW/HCS\nvgcCEXkfYjihbSEA4JSO42GZ5daAcFUzy7u1jr4oJnVUY1RzhbAeAQh3LHe5A74rm26XasI7R8Dp\nbIPT2eZvl9zQJJeCN1SJRNHvuGVqJs4/kVoZU0ZS17HIqAZQLPjkHbf0gIAMjAIjW7pPU20Ccxpn\nYFrdJKzr34AhbTs7h3TxcsyaK615GtIJk1qBsopcSUb9rPH1qEun/HMxxboYUzVzXSL01wFfxtXQ\nDJhG8GfCSSTjZGEp/uKptT6FUpiaLITj66sXbQ95FhNorvGTDGO6IVnOvshMNu/QlzBbwMiWNyGA\n6rG+6l65d0K27ImniVp2wE+kKuboPbKYWA50h1U7EAzZQ1DsONA5GvYW1pVMc0SiD0+aMxCDxa5X\nJm/umvcGqJiOVrsNg3JrTql0atsQJVtVVYVlXfAYKTAvCl3EObAMSU9/1HI8ufVZPLLhn3A9V5RM\n9bA2nUqImI3sNlfCyrcUV7yMB/dgeXMPSEz1PUCK6olqD368O1gFNZ4RLm0ZWTuHQk7DB0efAcDv\nA5ArOPjt31eL/Rw3PKntjV1BFy8PXcge7KGcjd89/1TZsX99diNyBiVEo3VtoFc5J38lNgRr0guw\nJtM6/lK1uqyUUf+/ax8WvyEZL+54Bc9L/elLUXSL5fr/kuXN8asVv8OPX75dhMe4Vfzwhn/ihy//\nP2zxVkKv2waoDhzPw6/+7w08tGRjwAshpH+55e2RMg+kKpP3btrfdrPnbLgGPwcCEXkfYphcOwE/\nOf67mNVwBCyj/PYYnGBUDzFTQ1XKwq1XHI0vf5hm+fK4LQBBvgH3ouwelC1vosHeMBVeVs5WD24v\nrJyPKUnfhWxqBj588nh89zPzMGMctRZiElnlmfEbs6jbvLBsIYobJ4rtinR+njRycvvx9OsUjyXE\nlMfdrSnPCsubeNTytgzNT3aDH1fk1xC3NMwd7y88eDcxUzOggCa1aFKHXMI8Dbyvugw5NGCqscC2\ns0afiqq+2eLvQHvOgHyqi6JXBCEKWup8z0jcNH0vCCccx0Qmb9MsbE+lMe+Cr3QH1wAhCLjNdU3y\nTngl91+aJy7/6hQMxE0NMS3OzmvT5iu6DZe4UPKViO+a5hOo6kiWN53LJ1/y+8bLMW+e8ObsGAli\nGzDa30Rfwbcqq6sly3Dlb2BNfRquPgRNVaEqCoqk3PImqitCSfLcPrrxMVz1xNegN9FSL9cjwoPh\nZYOLq0DCWgi5Q/WEBSqTN/GYfKtr+6ED9jwljARkHf3BnA3Xc7FhYBNMLyUWMLe8eBs2D/ou7oJt\nU8liR0c2xyxCRn6PLN0kyc8SbM9tx5cfvwlPrH8hMNwEy81wulqhQRM6/4E4t+KhR6WJX2rp619a\n5PXYdBHVP1TwNQXYgpiXBZaSmWx5bx7ahqU7yrPOf7Xy9/jNyj+Iv4u2G9B8v/Wl/8Y1T92EXfle\n/HrFH7Az0yme9VIZ1NV9a6HX00UQJ+A1fesD+6iJwYDbnH+XJ2nYc0+G63riPnLIev6lynUyIvKO\nAADQVPYjUspdedzyVlRPuNJrKmJCwjBgeTOrNpDMEaKqxlGZNANxW0teCDA0VfrkbmoGFEVBY7Xv\n9pXJm7+EYqbGXKBKILNazlavTtPjWpK+znhpkhx3hauJId8t7lLL2zTUACnpsWLgHKauBRc2jNhM\nzYSmqXC8oOXNyVtX9fLYqpQ3YKlWYNOpI09ERW6cv13zr5d7PbTa7XijawOKjg14KtobZcvbpN4F\nkDLL2/OISNrjpMFnvQAAIABJREFULwlFt0EcAx4jb3gaFIVlC3P3eWkSIRfaAc1GB2huQNzS/fun\nOYiZmp87UIhTS5yoILaBIjJSzLsI4mpYsqwbdz9MXZM524/r8nixl6mE09kORSEYcCh5X/Ghqaiu\nCU6vmhhCwaRuV8NQy8ibXoODdFJ+PoOWqNEuNVlhiwsu7AJQFz63vF3J8k5qKcQd2kRGUT3YbJ4H\n7SHEtBgaN58HZ/toAIBDbL8Gmn1Hykj4vyvFQ6Ho4q3e9cg5eaSdVnj9dWIMXUO+KtiWHuZKdw0M\nZei4RJMdRhp6y1rE5vwdD+/6LYrI4c8rngxc85Cdpde1fip01RALKNntrbe+BcegnhAufyyseOn3\ns9T5E17dvAFf+q9nxCLHigWJqbQ3Oq9l93rpb3ht34bAdtlb4HgObNfG13/+HP7th/8Sn29l5ZM3\nPPtdvLDzZfxr6xLkmOVdCHNJMw8cH2NNrDqwWUkMBkrF1vZuxuceuwbLu94Qn8ltb6EHr4mrJAI0\ncY4QgnvefACvdvnhCdvx0McSE4frQX8gEJH3IYzm2gROmNWKL15whPgskE0eAqOEcDVVCcgbkuEs\nbwC1lbEAoafiQWICgNYaP0lNjudyWJLb3M821/06TEk0hYu4AEBVih4XsFRLM+Cl97OaYGVWsuWt\ny+QddJsbuio0vGWYqgFNU6h2t7SY4Mk3pmqUkbec9CeTsxibNN+xEMtbjWXxt/7f04Q1TxPudtNQ\nYeq+Z0V0RXNMZPIOtSCY5e1fqA04BmzH893mAGJxlLnNP3bqBEHufFvOy9JnytMRs3RhvampPjqn\njLy9giU8EKQYR8YblFzGRX9O2Hf15XwrLONkaEzZMcR+3EqLWRo8zd83zix/T6X3z9BUOGD3UnwH\n/d5kXAqTlHTVcwerpG1sIWeb8DIVYpz8he95HqAX4RXi+FjHlTAddqzqsg531PJOm0nYDoSHpugV\nxQKAex8qrKT4XXHPx7Iu6vKOF5vhDVXD3joGAPDcm742+NZeupghjo6BQWZpMsubX6Wa3hUoAyxk\ng7+/jJ1hc6RAgyFCF778bT/05vVAMQEvkxZiQaVqfRxLNvtlc+NHVOH8U/zcDzW1C/IP8ub/eQld\ng9TFbO8YgYQex/qSJjy25xPj7ct+gy8/eQN2ubQpztbuTKjVammmkKQNI0au9Oc4dCxyxjtA3d6y\nbsCDq/8BAPjjW38Wn/FjHdcLvEMAQElI5J230Vvow5Nbl+Dny+4Sn+8azIuZiCzvCACoxfzRRRNw\nxBh/tT4hdQTcwSoU3pwdekwpucdMLaiQJJM3CZKZkDdlkBOpONpqq6T9yxcSlaZvRfKXWMyULT//\n/LpE3tzy1lTNJ9mSxUVx3RHwWMIUb0nKY96moQlXOOCX9nDi0jU1IBwiX4OuKtjUOYTfPbpOfM7L\no/RQ8vYXKLy0SoYWIG95MVOSw6C6UIkuLNiKhCnlNDjCCiCOQd3mhJM3LwVzoai0tj5XcFDgbnMA\nlim7zTUoACxDCyTNAUCB5JDQaC5CwtIRN+j86rU70KmsgWKypKd8TMwDKcThEgdEp5nJil70xXDY\ngi3LuscRQtDv7PLbz7Lvz7t0e6ezCTuTVNr0kxMvw1ktF9Ahs/71pqHCUYKVA2JRKB5PAhhFjEx3\n4OjmuXRqpC588iKosHIezGKNyDsAaLKiYhQB2/QXSACgUMvbIx6G7AzSZgq27Sc2Op4jkTf9jsp4\nWlLCo9v6cixbv0jnmeTpwAekBc6OPhZGcA08vIS6yGWyo99B5X7zrx1L50hx8M+XttA+5ZkiuocG\nxBzpii5i5tzw1iq7oShAYeN4EMeEogDZYkHEkkvj/tttP8FLU5WA0Iw1+XmYE18AJ/A1W/rx6rrt\nbJ4NjKrsQHd+FwaKfqWCHMte2fMmrftnv+MbfrEUP/uLb81yFN2i0DRwPSJCFv7AWNWJ68sJK0RF\n7oVF9JqsbMBtvnEbnfNdhV6YE16A3voWc6F71I3O3iFpvQLENpnbnC0M8g5eWeu3C/WIhy2dQ3h+\nVac/3ihhLcJwiGtxFFfNg9dfH7q9lLwtUyuxvOWENXr7501uxPc/O5+2/pMs7/GtJf5MMMuCn1sr\nt7wbEv5Cg1tIPGv5e5+Zh7PmjQ0dK7e8AYiM5dIMYJJP4SOTzg1+oWR5u92tZePxX8R+9q0MQzWk\nemH/+3gs2ND0snri0fX+NRoh5C3PtxJI6C8JA6gudMUQZXQVSRNtKRqX16q6AuSbZZa3r89OAuSe\nLTjCbQ4Ahkl8kvc0aBpLAJOS5gACm+TAeSNmakjr/uIrhz5R1uQWLDEPXoH1ZlcGac9yzRPhEL5Y\n4q1fu3O7YKMgLF5ueXsKJYo3B/3yqE1bbdz10Dq2ncdXM7Br3gKgCNLjC4BYjL2U9SIUBUjoSVw0\n4TwoTgyKbqOphu4vXP/FGEBUmAodP7f+O6uepIsgT4Preb4YDot5d2W74REP9fE62K4nFp02sSWl\nO2Z5mwkpt4Fu4yEEl1Vf8DnK2QUUbRe/fGgVVm1hjXpcXWyX8wb4dRLHFGI7ikoT2VZv7sMdf10B\nWymI+VWhS25ztsBgdegkmxZz2DkwhKLtQW9eC60mqKrW7/o1y6qqBKRhAdAOaZK1nnPZ78s10Jyg\nZX49OT80kA35/cmu+tfWlau6FdyicJtD8VhIyQfPc+GLqEwxA89mioKuCkVzgqV10m9Qq+wRTXgc\nhzDLm97Hj7Z/Bl6mIuClsl0Pv3/CL9+7c/n/4DtLbsMDT/niMZHlHWFYqHu4Y2aJNWwZWlD3N1Aq\nRh/kptoE6qvi0FQ1QO6TO2rF/08fU4svXzjdj8ejNL5OURvzCZ94Kl08MJd5Q3UCR030CVa4iAGk\nErIrmrfwLL/YMXV+TJx386pImNA1Bc7WsdQqIeUxfuIRVFoVpaeDqRmi5EgJqXU3VCMgvfjZD07B\n5R+YIf7Ww8hb2t+TpEcntNUFd9RtqIqOuGR5z2+eCxAFesNmP6Pe1ZHJ2zR2yaw6a+ozfptXx0Au\nHyTvbGIjI2h6nzVNoeTLXtqJlEvdpooHz6afxS0dST2Jwlv0+lzFFpa3V4gL0RauVpbxBkUSk8hl\nYN+/dscu/Owvy/HEm5ScSYaFW3jZGrdwTN+789yrg4J8XOY2R7IXUF2MVuaCFBOB73hsgJYUcosx\nriawoycLt6hDt1xRiREgb0BkxOftIjziIR9jFmM+IfIKAFoOt2pjL/64lNbzt6aaqQyqwlu32uVu\n81hSkD/XyOdVBbativtJPy/gsZe34ull27F5Fy2Rq02m/aQ/j7vNFYgcCFsqeWT3dzBrY8122mKX\ne0Dyeep2J4SAe43VWAbEU0AKcXGN37l7KdZvH4AxgvUzcHTkXliEpNMIBzZ4GZ+mKgErmkP+zdhM\nuY84hig5zEreroxdImID3+1N57A8abDgFkTCGkqSyVJGUhCrIyzvrK+q6OqA7gT7juvhSWe268F1\nCfNuqMgXJE8Zu8ai7QbG8GrXcmjpXvEbEfu8S4jI+z2GcLe5VPIVEvMOWJYSudek/Bfr/KlNmDra\nJ3MAAUEYDpnc4WkBlzk9p2+5VyV88RiZ8MQChBGVSGarS6IuLnkDPA26ptDEKkUBoIAUEtAhuarZ\nS8ojBIs6TkBV/wwa72OQyZk37pARqJsHVZKrkhYBCb085i27zT2pfj5VojKnKABcFQ3VCUwbXYu5\nExtQHauC5VZBiQ8FMuppqZhfh6omhkQmLHENDOVtFGzfbd6XXI7p09l99VToqsK6zrH5GPMMVNZb\nmj8TBduFoijwhmhoxEYOipmHQhTAlmLezPLut/tgxIPkzc/Vn83h+VWd+PsKSt6lljePLTpMMCb/\n+kLs7CmIuHavthHfff5HUHU6B5br51qIlynJQzFzIt5tKQlk8g6Ia8BVCjBYtUYpefNcjbybD1iD\nzrYxrByPC9FQ8n19K81gbk01U8ubu80JJe+EpQuXdtKK+d4PI0jeXPdAdPDzCsiKen3672lzx4rf\nnS2XWqksROKYoGI7qng+sgUnQJwAkMl6ICBwPMePeceyIIUEAFXyDri49wnfclR0ByAqFM8MzLWm\nKqHlcmD3UYkPQq3sogtqT4NC6Dhkb1fGLre8lVjWJz+jnFjzbkGUivE5cnc14gOpy2CqlvjMcT04\nnoO8mxeqisTVoaiOmGN6fSXfwd3ujkcXAJoNuAZts+zySgJ/n7LjpTkChkmqO0CIyPsww5566IZa\n3oGYd3mdNycbz/MCCWvyQqBUfrB0eygUglhJrbpsrVt6+PG8QQh/iR49tQkfP20irr5oBkzNpCtu\nNv50wixrSiCTN38RVqUs6KqO6sKkQHtUUzVEOdCYmjZ888jrMK5q9LDXaGhqoJZ9XHMdPnDkCNx0\n2Vz/slUFhVVzoearsKDZF27hgh4yKhMJ6JqKL104HfOnUq+CiQR9YRh+0h0tFSOB8h7enAWOgf6h\nYHY9APSxzm10kaMGLG8A0OtYwhT7bOG0ZurZYQRAyTEPnSQAKNA1BQumNWFEFQ3Z7Mr3wUqweHKJ\n5a0YBVimImWrJ6Brip/YptlQFF+qlQghGVVoxW8Z2ibmwHN8qV23zw8ZKUZRWN4WErSch32HbrJY\nrsVkMJnHIGXRf9/Y0oVfPkL7NDudbSDFOAtNcPJmI0pQi7Ml2URj3sxtTsnbRdzSka70UBmjrV1L\nyZ+7r4u8RxD7DRb1XRhwugHFhcZEgyqsBGrScRAiuc0VOW4vJe0x0ti4Y1C4r0lJ7kHRsyl560W6\nwGChB+Fh01ykkiE04AYXWSqzvHUvjuKGSbC3jaLbmSWqN9MFDsnR3vSKS8chk3e2pH4bAPSGLYjN\noNnmhuWTYFKpggIFBafot2FlY/EKcShOnC7CJLc5j6kHLG/NoUTMUGrdK6oHKJS4HZewcj0ahvIX\nOPQ7MnlbkLfT3QJnRzubA7o9bmmR5R1heBT3QN56iaVoGcGENeLJ2xXpv1wmskSqkyGsLWCY5Q0A\nY6voD5vYFsa0BF3VMtEamoFvf+JI3HrF0SXnpeTI5V1jpoZjp7eIuHiVVcmuRRMNW2QYEnlfcfZ0\nfP7caRjTSo/RNZW9YPh1aSJO1VSTQGOqBhWmbJkH57M0/p00E/jwiePQ3ugfo6kKvMFaJDedgOqY\n/7nc7IGjpabc2rcU+oJVOem4GnIFF7miCy3mZ1VziyWQxyBZ+l051vCFuc0NPRgWEYlVnobPnD0F\nNRUxen+IBuJqKHg5KLotXsS6ruITZ0zGNefT+9WT3wU9zsnbEucCAK1yF5qmv+W77l0dLbVJsVDQ\n67Yj1rLJF3MJkZYFAM9gCnBF+twkYjq83iZUDbIKDL0o9NKTajWyeceP+zJLTjGZNeZpuOXf5iPN\nyHv11h68vtFPsgJo8ppX4vZWrBwsNYZ7/rYJBP4C1CU0YU2zCsg4Q+iobGGeJhWEKFA1Rt6eDVM1\nYLOsZrHAqejBC7gPWsNmaJXsGow4aiuo0EvOLhey4fFuIkkFr9vZi9gUKpzCFy5y3NzziEgMEwtX\nISTjoChJ2Y6qYhnlJeENVSXoLw4grqThdnb4izUtaBUX3jgyMA5ZMjUjZYIHcmMAQPEEeRc3TMIM\n71xYmomiWxAxb1GD7eooFF0YqinKHh2X+Cp2IrFRh6J5yBSkeWTkrQ42wB1gZWWqC9vxqKdDs0Fc\ng3lwuOVNv38wa4v5cLvaUB1jybvsGY9behTzjjA89mR5lwovWKaGie30ITthVmvoS5KngHgEQQlR\nibiUEPIu7fTEccX0T6Bi2wkgmUqcefTIYcdqqDra6lOoqYiVfS6j1Hr33dYkKNTBj5dUz2pSScwc\n71tqhqYG6n0BX7iBZ30nDD9cUGZ5l2Sex8JKxbgbnhDELR1nzO/AvMmNWDituWxfSyuPmXPyVrhl\nzd2sRRfElPpCM3KXQx2Vdf6LWGQrexp0VaVubzlhx2KuVtfXntfZfSaOgSFniMqzshc5d5vH9TgS\nehy7cr2iJI+/zD+4YIw4f6eyRri947qJz35oqug0BwBoXYmCW4ACJZDMJ5frOBq1evM5Rt5snKpD\nCXj2tCTMup3wCnGkvUYqYcleuq/iL4DqUPJm46urjPueE83P6OcvfM8jovWqxvu6aw5yWQVLWJtM\ngy1aHTh0MZ2gNdod6XamSgfAU6FoLJud2DA1U2RNBxfQ/iKNz21N2gI8Ddtz23HHsrsgMvqlccLT\nxMKoM+tnO8ulcIBP3rwlqli4Sm5z3oa3Wm3CN46/KuCh4ffC1bLwiIdxDS04YVYrxrXUse22P5eA\nOG57Fx1vgLyZZXzVjE/jqhmfDswBvRd8gRLDUMaFpZk0YU3EvNn8cfJWTJFQ5rgeduWpp8kX86H/\nDhX8MSi6DS+TRmbVLH8Bwo7f5qyHogDeUGXAbS5yC3K2mA/iGJg7voWek2kimHpkeUfYDfbU67fU\nhZyMGWitT+G2Lx6DxaeMLy9XAkRrPyrmIFnGEonK33v6yJNRH68NxH5lWJqJq886ATd8bI7I+JXB\nm6+UeglKt/NyH8sILjgqmeWt6DbSyZBac0XOXA9+h6YpActbBpf7TOp+LH5PlndY9yQeo+eqcecd\nNwafPnsKmmuTqDQqA/uGldvFWekWfzEeN82vr1WkF7/O483SgqxdnVZ2PuL6lrdcIy7I3/W158e3\nV2H+lEY0pqtEaRBPaNOlhUtNrBo9+V7U1zOyZy/CU+YEdeCJ4oB4KmaOa0RTTQLfuuyowPa8W4Cl\nmaiuKF8EAUBRoyV7OUbeosENK9dzk53wFAfurkbs7M0hm7eF29RGHlrtdroAkcSBYixPQdHcMnf0\n8nW70DfAXMUa70jmBBZIwuOkuFTVLk5Jo6OizV9oen5M2iU2DCk8M5yXAQDqE3WoqYgJ1/1rXcuR\nJf2wJlBJUd/y1oU7l3sv7O0j4Q0wi5aR8zPbn4dLiBAb4QtX/syYY19FhtDx1+ktSFspGLoKj7e5\nZZamrdHjm5J1+OiiCWirpgaBYmVhjnsZWsUudk56n555lWaqb+rpFdfG3ea5IR1JI/heUHQbnsEs\nZ9tAf6YIS7NQCIl5wzGQLzq+qJLqoug6uH/NX+k1MhU7fo0i1q7QOm6xuJcaBdmOh60OFW5xu1tD\nLe+hrB2o8EjH6Hvig8eOwHc/PQ+WoUUx7wjDY97kJswaX4+vLZ4Vup1nVNdZ9Zg2uhZnL6Qu7GSM\nJWaFkTezvUvbBcqiJvKmM0Yvwk3zrw0mp5WgtjKGUc3h5M7JebiYOd8u9MdLkt7SJn0BKbqDdLyc\nvGXLu1RIxtBUv+SoBJwYZMtbLyFXTmAfnXQhxlaNwsiKkqYlAM6Y34HT5rXjU2dNLtv2hWlXBWr0\nwzL2E5rk1lc0jGzyCb+m+xhfI54n+kj3tEFvxw+O/ffgCT1VinlL99RgbnfPz3jXVBWfOmsK6lL+\nvXOKLAFLWrjUxqphezZ67R7qvuS92y3//GkjhYq0CsXTce6xNI9A04KLy135PliaiY+dOhFnzO8o\nmwvCwgCdPVRHnN8jTmI7MszqtC1s685Qy1tWA+WhBdvCSbOobn/CYG1Nx7wu+otzwn9pdZcIJ2ga\nK8nTnMACSSgPcnI26QJjRLoVlsmS+jwNHmhfe9uz0dMnJToNoxxYvXURLM2kSoeyVCk2+vMhW95C\n598RcyD2Y8f/c9OT6NPXQIlTD0ap5a1oHvQxNJuee5Fk8halWCo9vi5OibE6Sc9jtKyHVs3ugfQc\ncm8N11bo7suhM0sJ/Sf3vFn221fTPSC1G0CKFrxsBdZvH0BPn4OcUxAiLVzbgdgx5G0XGgwxxgIZ\nQme2G9PrpooWvdzyzhTZ74Qt1OIaj/v7mgeO6yGDXpCiBZJL0wx1fs/ZImkoZ0uuewOVTGggFgcq\nUxZMQ0XRdvdoYO0vhJs+EQ5ZWKaGK88tt644UkYS35p/HSrMVHhMOqS1J3/Zlbb+k6340pZ77wS+\n5R1O/pogb/riLiXvuJThHeY214gpvqd0gUHJV8XE+Cx0NFQFtnELf3duc+5Wntc8B/Oa54SOP27p\nuOD4saHbUrFYwPIPu0dJPQk4/vbqtH+9llcJe/0UWJOfB1GZFSBZhZahIaZbSOoJP8bo6dBUBTFL\ng9vTgqJuw+zw5SHh6mVd32TLyCkyYpeItyZO44V9hX5U6JXg7RsURUFh9UxY41+B7dlIWAZqrKQI\njWglnouck0M6UYdpo2sxbXQtHlqyEYWVR6FuyhoMEhazJ4Bjq+hoTQjPByfvXqaRbqoxbO/JImbq\nouc44MtbLpzSjounUNnadExanNWzpL1Aq1z6Ha45CKj1rCpAmmON11mzeD57oSeNhL/wJCo8uEhY\nOlzVLbG2gwsYTt5xRp7phBH4neaJn+XtDXBi8suYeLlVIJ9B+v+COgjVGoKXjwsPguyB4z9z/rsy\nNBW9/QRWo+/9Kap0DPUsVl2TKM/VCHw/I/8iyeP1tT348f8+g9gR6+EO1ACuIQiZQ6vugqIAxS3j\nAU9HvujCLChQzaLoC6ym+kEIdWsXii7i4HF5FzZT4RvoKxeEGsxnAZjQ0rS6okKvQhcQ0DywHQ8O\n8ZUCs3lb/K54eCJrboNVxWrfPRWVcbqIzjssVBUbgDbiDeSKp5XNzYFAZHm/B1EXrxk2mexblx9V\n9hmnZbIbgi61yt8JRFx+mFOqSnncXkZCcmvL7U55rJlnm4dZtfzlf3TNSTh7zKmh35PYjdv8ncLU\ntYALN8z7kNb9a7JUU7jhAZoMN29C0NqX3eomW4BUxZi1ThSAKNBUBcmYgfqqONydI6ES2UrSAhYz\nAD+jHxAvYpng5Xr+ilgSs8bXi0XloglzUUmakXcLyDm5gJiPripCHpQjVhL394aqMdM8xf/A1QEo\naK5Jipp8txict7pUBTp7cxjMFuF2t+KktuMB+PKWDekKUXWRNMt/G7LkbYJtH4qvg9FOFzky2QkJ\nYJ6Mp9iIaVbwuXV05L0cYpVDVMSmpLJDBifvGHvuUnED8o8jD2r1FlbPFORbEaf7ajXbpQ575RoO\nAPWsKUYRhCWr3fjxuThnQfnikqvrmYYq7jl3mxdUujyrZ5Z3Q0U5eZeqNxJPRc7JY9naHiEA43bS\nZ3XZup7gHHDPgOwV83QoCi+1I1CT/SC5FFRCyZ9b3takpbAVapWv3iDVkrt8AcF6rTdSQZZRFvOI\nSZZ10XHhoCg8BnLuBPds8GeBjRhxg3fQo+SdSa6F0bwBm3qD7UoPFCLyfp9hREMKN867BjcvuEF8\nxt08Lvt3XP+5uOGoqwPH7U/y5pa1S8KTO9QSy7s05t2UpOpN9bHaQO05fzlzyzuMGPnLP6xjG/+e\nZIjlffHJ4zB+RFVACW5fQL9fETHN/kJ5b2ceFgAASw+St6oquPC4oDu+QrIkeQ9s/pKloQefMK5f\nPBvzpjQibUrk7Oplcyxn3E8f2YxPnDEJx83wBXZ4xj9A5+vKc6dhFksMvPCEsRjZQL8/5+QDSXma\npsLZOg7FdVP9awxJ2otp/vg4cTbX+Za36ypI6v51N1VVwiME67cPQFVUnNi+AIBfTiff0zE1I/zs\neAauugbQOefQG1g3L4mYDFWjiyJmeXuqLeLoHPa2MSDwUGxaRj/gmvNmubdJMQsgRKEd5cDIW/N/\nG0VlqGwMHfXU82GOXAU11ReYJ/n7AKDIiJfY1LXb0ZTG+NagZgPgaxYQ4ru9eYJWERkYqiEWdfXp\n8pBYaSIeHAMFL49ETIfCeoDzkM+ytUHyVpmSn+w14Za8Yuap7oHmwstUwjI1arm7vuVcTG1mx0jN\nhQIxawIt3Y8RqTbUWLWB8Sqai6ydp78VtmjpzxQDx1tSCaCzg4Z3+D1/bPNTeHzz037mPQn3KO5v\nROT9PkRDog6VVvnKmbvGLSWFpmRDYFtIXtY+QxXkHX5SlcfaWcy71FoZXdmBzx7xcXxp9hXB47ik\nNCPv0pp3gFs1VIq0FMJtHmJ5nzJnBK67ZFawZn4foCgKrr5oBmbWURWz0sQdAEiYMVHrbGkmkjFd\nkLKmBUkLAJKmFONn19DMFjgcfOlVmbLw6bOmoDImddjytLJEx3qplOeoiW1YMK1ZzB2AwPOTCLsG\naQ5ly5vfS/klHUbeimv6nhNO3jVJcbzreUhLY2itoZ6ATN5BIqajwkoHqiHkMVZYaehvLkL+9YXi\nM/k+xIzdh5scj8BQLKiJQSjxAXhKOXl7fY1I6Wm4Zl/g+NLKCQFXF2JKybgRVC5TWaxXImdu9QF+\nXH+4RLiiTscwsq4ON1xKQz1hiZIJk96zwWxRsjqZ2xw5VJgp8ZzIz5x8DYCfsEk8DXllANvct0SN\nNo9Db9hRrtYG0JJD/qyLOTviaRH+IPkkYqaGfNHBUM6fI0/3NQ9Kx2N0rGJzRFATr0KMe5mE5e0K\nsR5O/rmCFPPWHDom1QOxTdibJtFxSc/tfW89KMJYils+twcCEXm/j/GlC6ejrT6JY46gJQ+cvGWC\nGsvqoxtq4uUn2EfwOHRYpjbgW+Ya++2kQmq5p9VNLluAcLe5RuiPKszyPml2G7568UyMaPDJ6+On\nTUR7QwpjWqk1EbS8939ayOSRNbjsiPNxycTzsajjxLLtluRaNzUqQiMatygKNFULvDiC5E3nroy8\nSxwnSSNoeZeiPu6Tt0zEHBVSA5rQBUiAvP2xcs+HHDoI08h3XCI8LPwl3taQFDK6iZiBasn676j3\n3fiJmA5VUZHQ/WssXfAYugqST6Hw5mwU101FXaU/3ngIecvEWSi6mKgfDUVzoTdtgIci4iELkLRR\nCaIEO7uVhif880slmpoaIG9P9ZOkOJRA1QDvXS/FsaUua45OiW/e+A7RwS6szDNlMNlbqVaeyt8S\n2MghLXljShd79BqY5rzJa8jpta8k//TVANl5t3aXS6USxwCIhoZqdi8kmeOWDno9Zx45HjGTajP0\nZf2ySRKKQnq/AAAbb0lEQVQbCJxfHo+i29CbN7BrTIjx8Xtijl6OdblV9CBXFzkAckKbodN7Egif\nlNxzy6I/suaaYEXJgUJE3u9jTBtdi29/4ihhhXLXuKym9qULp+Nri2dhTMv+eyC55T0cefPt6YSO\nb19+JCrfpqv6iDHUHdZWzVyKIdZFzNQxqaM68PI5dnoLbrr8SBh6iOUdco79AV3VcXTLkaFjNAzV\n713NSJeTN19Y8aoCUzUQtyTVOmF5S33R6ZkCf6UMyfIOJW/frZowysm7UnqRlxJj6THyNXLrkkus\nAggo1vE6esfxfCEPRkS1FTGcMa8DC6Y14XPnTEW15Sccjm32a/m5cE9ausbSaxC1+P31cLvbxLMD\nhJO3vMAp2C7aY+MB0HI7opAyyxsAKqQmL9yKa2sIL1NEqbWmlv825PvUW/Sbhvglf/52N6RxkRyO\naU+34YNjTgsQZMqS480avEwaWsUuaPVbQBQvcDwAjKroQEqpFp2+xjbV4rR57SK0YW8eL/bV0n1C\nOnU48FBGQzUTKUr6IaW8QRu3jKqvQ8LSkc07yG0ZAQyy6xS96/05mDLC9x7yTPWkkfS9H9K+y7LP\n0jE4BkawBQ4kt7mha9QL4egY1ZzGly+cXrbodPUcTM1EOvHOQmtvFxF5RxDglrfspo5bOsa1VQ13\nyD5B3UPMW1jm8IZ/2YXg46dNxFXnHYGFU6hs4R7lW4eBoRnCZbuv53gnMHVNvMy5vCTPOOfZ2jUx\npg6lKIhJ8WruNi9VsCq1vNvSkmBMyAtVJtQwy1te1ISRe3wYy1tkrEtWolySyMvRbMcT18jjoYqi\nIBEz8IkzJqO5NonqmL+gbK3zn9EjxtJrr5LIPVGywCgNxUwf689XmEuYuJqwVfNFF0nTAnE1EVMP\nI++URN7cyjv32NE4//gxZfsSV99ziZFENos6ThL/z/UQZOudZCqRe/HkQGy/QiJfRVGwqOMEJGzf\nQ5OWyRsKiutoAqLesLnseAC4es7nME+7UCwARjdX4YLjx4pyPrenFflXj5PGT3uNDwfujeGWt73N\nn6ch0Bh52kxhVEsFXI+gp9dD88CxwXOwRe8lp4zHF844DpW5CQAANUkt85SRRJznHYQtJFzd98rx\nksGqbrg1a2jioWvgklMmYOro2jLvQ09uV6gH5kAhIu8IAsfPpAlJs8aHtxvdXxhTORJAmHVIMb1u\nCh1P24K9Om/M1DFjXB3SZhKWZgaSrvYWPEZ6INzme4IpWd5claqmgvc7py8MTmxFtwhLiqNazHug\nqzqumvFptPfRspVSWhhd6ddUf2NxeQWCjHgIecsItbyHiXnLXp20Qq1dXv8L+Ja37bpoSNDnkBBg\nzsRgDgYAVPMFDILiOVzJri3V4o+x1PKWyPu4GS1oqfWvIWUl8K3510IfkhY4ro6PnEItyUVzR9De\n6LYpuqrFJaW9tnrqrm9IBpvoANSDcvq8Dnz9yC/jxBHHBM4vo3YoqONAXE2QCQBMrh+Lr5TkfJSF\nPzxdlNQBCP09aNLzLSc+AsBXPngs4GqC+NIhxxMoYlw8h+Wy0yYKWWRSjAnLXHZpA0DlllNwztgz\n/HOxPIiKBPME9jYF8hIA6k2Z2O7f95baCmiuf2/5d+SLDlRFRYfH+ruzkreUZHmHClY5BkYIqWMF\nbj99RovJbfQjVy+rfvHnItwDc6AQ1XlHEDh9XgcWTmt+227qfcWF4z+ICdVjMatxeuj2SbXj8b2F\n3wyWK+0FNFXDNXM+H3AN7y0Sehx9hf6DYnkbuubXm7JabRHzZqRTKxGX/DIxpSz6CTVjYbkZAD1l\n7D0i5WeOj24O96x8ceZnsHFwS5m7tBRhZYmyNR6WkAYA7eZErCg8g+aUb/0J8nY8zG2cgXV9GxDL\njMTZx00qO77GCo77psvmYiBbFHM1qmoEsCV8DHweZ42vx8dOnRgcu6WjLl6LOSNH47lupn3u6Zg/\npQknzaZCL6+s7mJSpdQzInsqvrZ4Nrr6cuhVN4nvT5oWvvrJo0TYoCXVhIUtR+GxzU/R87s6IHHC\nCGUatrxYg/icf9APSohGVYKJi6RE2lh87vj3Jox8PdZ61elpQoKFX266bC5Wb+7D5JE1UFdUwovv\nYseXPwfHTm/GP5dyOWBK3o01CXz90jlYuWEX/vPe10DsGHXtl1xDtVGHk9tnYkzlKNz94t+xsZMu\ndqaMqsH0TbUY0ZjCX5esByGK8C6kzBTGj/AXXifPbsPmt6rQyYVYuDgMlzw2LZCiJRZZKTOJmDK8\n5U1cHUdPbcIf/klbpBbfnIPYnL/BNQbE9pgxvOt/uGf9QCCyvCMIKIpywIkboC7Go5pn79aqTUuZ\nrfuCpmQjUua+kT9A3c5pI1VWc/5uwNJV1roRqGSJYaUxb9EUAcGyt9KSLz6DpIS9Dc3AhOqxGF9V\n7sLlGFc9Bie3Hzfsdo6w8IdcBx6WkAYAU5Nz8NkjPo6zR/v19kdNpkQ+sb0auqrjkknn47w5c0Q+\nggzZbQ4A7Y1pTB3lx65HVUqysiXPkio66ZW7qrkVP76+zf/Q1f0sZQCmqQUy5mXhoLilo70xHcgb\nOG56G1rqgs9jXPYGuDpOnOV/XzpBVdZ4YlmYlRhIOvR8adLW+iTOOYY1B5LqpsPCHzWDM+H2NqCt\nOE/McXtjGiczmVut6Lv+wyz3uso4KhL02r2S52DyyBpceMJYv1lKyTV091PCHVXZjqnG8aLne1NN\nAl+4YDqV2iWqyI8wVB2WZiIVN/CJMybhyxdOR3tjGpMa/C6AHz91Mlrrk2KR1TdUCCRHUstbCx0P\nAMAxkIobuP6js3Hy7DYACohtwVN5A52g5X3JxPOFpxAID58cKESWd4QIIfjY5ItQcIvvaAGxr9B1\nFW5nG2yjgCvPOx8AUJP21a+AYDKWXH5klpK3SJ0t/56rZn66/MO9wHnjzsIDax7CxJpxZdsaE37o\nZTjytkwd0+qCNevnHDMacyc2BKoBhgOPaZdm1nPwpD7e31lGW30SmzuHUFc1/MtWvi7iagGXv6Vr\nAZd02Eu7OdmI5mQjtmd2oqmy3LshW84nzhyB8SP8fToa0wAUmEocBZIVVutpR7ULyeOEEYcChS7M\nJCI6YWYrTpzVhhNnt+Hz/1WEmuqDSozQZ/mCo+bi0aWNWPyhCaFzoBerwIVd08N4ssZVjcZLna/5\n1QESqtIWyFZekkUt81HNVP50guT+lhdn3PuSZImH3kAN1FgWtudn4C+QmvzMbZyJf215BgBNPj12\nuh8uaapJYHl3HGqKJr8ljSRivN+BY+L85o/h3rcegJryLWuAVtmMba3EkZMace+W5diapS4U4hgB\nsaKjW47ElNqJeK2b9q1/N2PeEXlHiBACUzOHVak70DCYhKuzdRzqE9R6a29M4ZxjRokOaTweXB+v\nDVjbZoj4DDCsmN07wokjjgnGbSXIRCFaNZYgbKyqqgTaq+4Opmbg5gXfGHZxAAC5l06iL9sPBD//\n6AcmoK0hxayr0rHTfyulxjtfuXB2YB/TUAPkHQ/pLqcoCq6Z83m81rUC0+unlm3XVA03zbsWf1n3\nMOY1B88/aSQlNiPXgEJsA5QEJRdNU8X9VhUVcT2GrJMrkTv1O7DpJIbCiqNRlQ4nlTEtlbjinOHl\nluOZDgwZW2FU9pZpP3AsnnQBptVNxuyQMFgypkvtR6llfvyMFpxz7KhABYvrlmfX88WS21/ni+WE\nYGTFCDQmGtCSKs+h+dAxo5B5eRJeGKKqZykjAVPKjxhd3Q4vlxbkXZr1P7atEvW91YK8DcUs03pI\nmynoqg7HcyLLO0KE9zOSMQOXfmBCwPpUFAVnLRgl/q6NV+PauVehxqoWtdMAy1QPwbvUKyGARR0n\n4G8bH0d7upwggXIvwb6gcpjOdhy3f+kUhDlPYqaO044qb4RSio9Nvggv7HwFExpaA5/HTC0QTx7u\npW1qJuY2zRz2/PWJWnxy6uKyz6tSFlrrkujcUgN97Aa/R3XJjeTiIh111Th/8Ww8+vwmHD2Fkpii\nKHBcD4CCUU27n6fhoCkGiqtnY+KY6mFzH3Z3jRPaq1D7Vgp96BYKZJapBcIbANA7SGPSFSG9Cnij\nkXFVo8u2AfQ6bzjqK6GeBUPXcMmcU/DCE4/T79aswH6WqQXaKB83pTyMJHdPTJrhXRJrrCp05roD\nuQ8HGhF5R4hwCIJn/u8OYaSol3Tt8t9T7z57nz36VBzVNCvUnQr4mfEHEqX913eHay6eiQeeWheY\n+yObZuHIpvIOfg3VCZw4bQyeHqB61/EDYHGdd/wY/PTPWRQ3TII3SC3x0kXYpJrxWLVrNU4ffRLG\n1ldibFu4FT1hxL6Ve/L5U/YxPUpTVZw//Rj8YvkGuF10XsMWmJUpujiZPKqmbBscE1+c/CWMqKsu\n38awu/CWoer4ztHXI+fky/bTVQWktxVesg+jnWNw6YfKEyPlBWKlVU7eAK3+6Mx1v6sJaxF5R4jw\nHkLpy4n/fRAMbyiKMixxA8O7+A8WJnZU42sds/e8I8MHpx8JrO2GSzyMZuWP+xMzxtZhbGsVVm30\nPQSliYeXT/kICm4xkMAYhpHDtOfdEy47bRLufvRNXHRSeV7D28XMhmn45lFfxdeefx1AeF+Bs44e\nicqkhWOOaA58fvnpk/Dy6i6MaWh6R9LE1bEqhFG/rqtQMrUoLF+A6qnhYYFKSU2wKhHufeClm2Hh\nkwOFiLwjRHg/4GCw9x5Qmhl/uCFhxHHxxPMO6Hc01iSwamOv+LvU8k4YiVBteY5vXDoHb27uxbi2\nfVNIbKlL4tpLyj0Pe4vGZD14NnxYuMTQNZEhLmPhEc1YWELo+xO6poqSjPgwuvNT6yahVZuADVuK\n+MAJ4eWtnLwjt3mECBH2Cj/43AK4XnnSz26SzQ869kfM+72OxupgedecCeHW4XAY3VKB0S37ZnUf\nKBxKizZFAbhBP5znPWkkcN2xl9Me4lY4ZY6pot6RltSBW2iUIiLvCBHeA6geJpt4yqgavPRmF2aN\nrwvdfjBxqLnND0U0VvtW9e1XH79XMfxDFeYwCmXvJj555iS8sbGPlX3tObSkKsqwxA0A46vH4gfH\n/ntkeUeIEGH/4NjpLRjVVPG26qbfbZjvASI60JgyqhqTOqpx9NSm9wRxA76lezBx9NRmHD2VWsn7\nK6fz3SRuICLvCBHe01AVBR1N+67xfiDwqbMmY/32gVDVtAhBGLqGr148fKnZ4YTT53Xg2eXbUfUu\nqDjuDfzQ0qEYXBoeB3Qpd/PNN+PDH/4wLrroIrz++uuBbc8++yzOP/98fPjDH8Z///d/H8hhRIgQ\n4RDC/ClN+MjJ4/e8Y4T3FM4/fgx+eOXCQBOZQwEXn0wz6WVltsMBB8zyfv7557Fx40bcc889WLt2\nLa6//nrcc889Yvt3vvMd3HnnnWhsbMTixYvxgQ98AGPHjj1Qw4kQIUKECBHKILvQDyccsCXQkiVL\ncPLJJwMAxowZg/7+fgwNDQEANm/ejMrKSjQ3N0NVVRx33HFYsmTJgRpKhAgRIkSI8J7CAbO8u7u7\nMWWK322lpqYGXV1dSKVS6OrqQk1NTWDb5s2bd3u+6uoE9P0cI6uvP7RigYcronl854jm8J0jmsP9\ng2ge3znejTl81xLWSjV59xa9vdn9NBKK+vo0uroG9+s534+I5vGdI5rDd45oDvcPonl859jfczjc\nQuCAuc0bGhrQ3d0t/u7s7ER9fX3otp07d6KhYe/EByJEiBAhQoT3Kw4YeS9YsACPPvooAGDFihVo\naGhAKkVrTdva2jA0NIQtW7bAcRw8/vjjWLBgwYEaSoQIESJEiPCewgFzm8+aNQtTpkzBRRddBEVR\ncOONN+L+++9HOp3GKaecgptuuglf+cpXAACnn346Ro0atYczRogQIUKECBEAQCHvNBj9LmF/x2Gi\n2M7+QTSP7xzRHL5zRHO4fxDN4zvHYR/zjhAhQoQIESIcGETkHSFChAgRIhxmiMg7QoQIESJEOMwQ\nkXeECBEiRIhwmCEi7wgRIkSIEOEww2GTbR4hQoQIESJEoIgs7wgRIkSIEOEwQ0TeESJEiBAhwmGG\niLwjRIgQIUKEwwwReUeIECFChAiHGSLyjhAhQoQIEQ4zROQdIUKECBEiHGY4YF3FDmXcfPPNeO21\n16AoCq6//nocccQRB3tIhzRWr16NK664Ah//+MexePFibN++Hddccw1c10V9fT3+4z/+A6Zp4sEH\nH8RvfvMbqKqKCy+8EBdccMHBHvohg1tuuQUvvfQSHMfBZz7zGUybNi2aw71ALpfDddddh56eHhQK\nBVxxxRWYOHFiNIf7iHw+jzPPPBNXXHEF5s+fH83jXmDp0qX4whe+gHHjxgEAxo8fj09+8pPv/hyS\n9xmWLl1KPv3pTxNCCFmzZg258MILD/KIDm1kMhmyePFi8o1vfIPcfffdhBBCrrvuOvJ///d/hBBC\nfvCDH5Df/va3JJPJkEWLFpGBgQGSy+XIGWecQXp7ew/m0A8ZLFmyhHzyk58khBCya9cuctxxx0Vz\nuJd46KGHyB133EEIIWTLli1k0aJF0Ry+A/zwhz8k5557LvnTn/4UzeNe4rnnniOf//znA58djDl8\n37nNlyxZgpNPPhkAMGbMGPT392NoaOggj+rQhWma+PnPf46Ghgbx2dKlS3HSSScBAE444QQsWbIE\nr732GqZNm4Z0Oo1YLIZZs2bh5ZdfPljDPqQwd+5c/PjHPwYAVFRUIJfLRXO4lzj99NPxqU99CgCw\nfft2NDY2RnO4j1i7di3WrFmD448/HkD0e94fOBhz+L4j7+7ublRXV4u/a2pq0NXVdRBHdGhD13XE\nYrHAZ7lcDqZpAgBqa2vR1dWF7u5u1NTUiH2iefWhaRoSiQQA4L777sOxxx4bzeE+4qKLLsLVV1+N\n66+/PprDfcT3v/99XHfddeLvaB73HmvWrMFnP/tZXHzxxXjmmWcOyhy+L2PeMkikDvuOMNz8RfNa\njn/84x+477778Mtf/hKLFi0Sn0dz+Pbxhz/8AatWrcJXv/rVwPxEc/j28Oc//xkzZszAiBEjQrdH\n87hnjBw5EldeeSVOO+00bN68GZdeeilc1xXb3605fN+Rd0NDA7q7u8XfnZ2dqK+vP4gjOvyQSCSQ\nz+cRi8Wwc+dONDQ0hM7rjBkzDuIoDy089dRT+NnPfoZf/OIXSKfT0RzuJZYvX47a2lo0Nzdj0qRJ\ncF0XyWQymsO9xBNPPIHNmzfjiSeewI4dO2CaZvQs7iUaGxtx+umnAwDa29tRV1eHZcuWvetz+L5z\nmy9YsACPPvooAGDFihVoaGhAKpU6yKM6vHD00UeLOfzb3/6GY445BtOnT8eyZcswMDCATCaDl19+\nGXPmzDnIIz00MDg4iFtuuQW33347qqqqAERzuLd48cUX8ctf/hIADX1ls9loDvcBP/rRj/CnP/0J\n9957Ly644AJcccUV0TzuJR588EHceeedAICuri709PTg3HPPfdfn8H3ZVezWW2/Fiy++CEVRcOON\nN2LixIkHe0iHLJYvX47vf//72Lp1K3RdR2NjI2699VZcd911KBQKaGlpwXe/+10YhoFHHnkEd955\nJxRFweLFi3H22Wcf7OEfErjnnntw2223YdSoUeKz733ve/jGN74RzeHbRD6fx9e//nVs374d+Xwe\nV155JaZOnYprr702msN9xG233YbW1lYsXLgwmse9wNDQEK6++moMDAzAtm1ceeWVmDRp0rs+h+9L\n8o4QIUKECBEOZ7zv3OYRIkSIECHC4Y6IvCNEiBAhQoTDDBF5R4gQIUKECIcZIvKOECFChAgRDjNE\n5B0hQoQIESIcZnjfibREiHC44ZZbbsGyZctQKBSwcuVKzJw5EwBw3nnn4UMf+tDbOscdd9yB8ePH\nCz3rMHz0ox/Fr3/9a2iatj+GHcDOnTuxbt06zJ8/f7+fO0KE9yOiUrEIEQ4TbNmyBR/5yEfw5JNP\nHuyh7DUefPBBrF27Fl/60pcO9lAiRHhPILK8I0Q4jHHbbbdhy5Yt2LZtG6699lrk83nceuutME0T\n+XweN954I6ZMmYLrrrsOs2fPxvz58/Fv//ZvWLhwIV5//XVkMhncfvvtaGxsxIQJE7BixQr89Kc/\nRV9fH3bs2IGNGzfiqKOOwg033IBCoYBrr70WW7duRVNTEzRNw4IFCwI9ijOZDL7yla9gYGAAjuPg\nhBNOwJlnnokf/ehHIISgqqoKl1xyCb797W9j48aNyGQyOPPMM3H55Zfj/vvvx9///ncoioKdO3di\n9OjRuPnmm2EYxkGc4QgRDk1EMe8IEQ5zbNmyBXfddRemTp2Kvr4+3HTTTbjrrrtw6aWX4vbbby/b\nf+3atTj33HPx29/+FpMmTcLDDz9cts/KlSvxk5/8BPfddx/uv/9+9Pf348EHH4TjOPjjH/+Ib37z\nm3jmmWfKjnv22WfhOA5+97vf4Q9/+AMSiQRaW1txzjnn4Oyzz8Zll12Gu+66Cw0NDbj77rvxxz/+\nEQ899BDeeOMNAMCyZctw66234r777sO2bdsOSy9DhAjvBiLLO0KEwxzTp0+HoigAgLq6Otxyyy0o\nFAoYHBxEZWVl2f7V1dUYN24cAKClpQV9fX1l+8yePRuapkHTNFRXV6O/vx+rVq3CkUceCQCor6/H\n7Nmzy46bNWsWfvKTn+ALX/gCjjvuOFxwwQVQ1aCNsHTpUuzYsQMvvPACAKBYLGLTpk3ieN4+debM\nmVi7dq3okxwhQgQfEXlHiHCYQ3YrX3PNNfjWt76F+fPn4/HHHxfNPGSUJqSFpb2E7eN5XoCIS0kZ\noL2M//KXv+CVV17BP//5T5x33nl44IEHAvuYponPfe5zOPXUUwOf33///fA8b7fjihAhAkXkNo8Q\n4T2E7u5ujBs3Dq7r4pFHHkGxWNxv5x49ejReeeUVAEBPTw9eeun/t3eHOAoDYRTHHyGYJlwAMAjg\nAFROSC0STCWCIJCYBhwOwxEqegIkuqLBbRN0LQaBxkBZsdkaDJutmeb/05PJ517eZCbz9bYmSRLF\ncazhcKggCOQ4jm63m2q1mh6Ph6SfVv97VJ/nuXa7XdH+z+ez7ve7Xq+X0jTVYDAobX6gSmjeQIUs\nFgvNZjO1Wi3N53MFQaAoikrZezqdKo5j+b6vTqcj13XfGnq329V6vVYYhqrX6zLGqN1uy3VdrVYr\nNRoNLZdLZVkm3/f1fD7leV7xVWq/39dms9HlclGv15MxppTZgarhqRiAj1yvV6VpqvF4rDzPNZlM\ntN1ui3fn/3U4HHQ6nbTf70vZD6gymjeAjzSbTR2Px+J/4tFoVFpwA/gbmjcAAJbhwhoAAJYhvAEA\nsAzhDQCAZQhvAAAsQ3gDAGAZwhsAAMt8AxJ5C+54P8QOAAAAAElFTkSuQmCC\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAe8AAAFnCAYAAACPasF4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzsvXe8XVWZ///e5dTba3pCQiAJCSWE\nIJGmoSSgjsg4gmCb4Tf+dCwURUdEQXGs41gYFQvDiIyIiKIIJIAgEBJCgJBKertpt59z76m7fv9Y\nu55zboiQBCL783rllXt2WXvttfden6et55Fs27aJECFChAgRIhw1kF/vDkSIECFChAgR/jZE5B0h\nQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8IESJEiBDhKENE3hEiRIgQIcJRhoi8I7yp\nMW3aND796U9Xbf/iF7/ItGnTQsfdcMMNoWOWL1/OBz/4QQB2797NCSec4O3btWsXH/vYx1iwYAEL\nFizgkksu4bHHHgPgpptuYuHChSxcuJCZM2fy9re/3fudy+VC19A0jfvvv/9vvq/Vq1dz1VVXHdSx\nDzzwAF/72tde9bVcvNbz3wi46667+P73v/96dyNChFeE+np3IEKE1xsbN24kl8tRX18PCBJas2ZN\n1XErVqxg/fr1IZIeCZ/97Gd597vfzW233QbAqlWr+PCHP8zDDz/MV77yFe+4+fPn8+1vf5vTTjut\nZjvr16/n/vvv55JLLvmb7umkk07i9ttvP6hjly5dyvnnn/+qr+XitZ7/RsAHPvCB17sLESIcFCLN\nO8KbHm95y1t49NFHvd9LlizhxBNPrDruuuuu4+tf//pBtblp0yZOPvlk7/fJJ5/M4sWLGT169EH3\nq6+vj09+8pO89NJLXHHFFYCwAPz0pz9lwYIFmKbJypUrufTSS1m4cCEXX3wxS5cuBYRV4IILLgDg\n1ltv5atf/Sqf+MQnOO+883jve99LT0+Pd53ly5czffr0qmu98MIL/OM//iMXXHAB73vf++jq6gKg\nu7ubD3/4w1x88cWcf/75fO9736vZ18p7ueqqq1i4cCHz58/njjvu8PatXbuWSy+9lAULFvCBD3zA\nu85I26dNm8b+/fu9893fy5cv5/LLL+fqq6/mM5/5DAD33nsvF110ERdeeCFXXnkle/bsAcC2bb7x\njW8wf/58FixYwC9+8QtvrL74xS8CsH///pD15MknnwTAMAy++MUvsmDBAi644AI++clPVllMIkQ4\n3IjIO8KbHhdddBF//vOfvd8PPvggCxcurHmcbdssWrToFds855xz+PSnP82dd97J1q1bARg1ahSS\nJB10v9rb27nuuus45ZRT+PWvf+1tt22bxYsXoygKX/7yl7nqqqtYtGgRH/3oR7nppptqtrVo0SJu\nuOEGHnvsMdra2rjvvvsA2Lp1Kx0dHYwbNy50rVwux8c//nGuu+46Hn30UT70oQ9x9dVXA/C///u/\nzJ07l4ceeogHHniArq4uLMuq2VcXP/nJTxg/fjyLFi3il7/8Jd/97nfZt28fIISiq6++msWLF3P+\n+edzyy23HHD7gbB+/Xouv/xyvvvd79Lf389Xv/pV7rjjDh555BEmTpzIj3/8YwD+9Kc/sXr1ahYv\nXsx9993HXXfdxerVq0Ntff7zn2f69OksXryYn/3sZ3zuc59jcHCQJUuWsHv3bhYtWsQjjzzC1KlT\nWbly5Sv2LUKEQ4mIvCO86XH66aezefNm+vv7KRaLrFy5knnz5tU89oYbbuA///M/KZfLB2zzO9/5\nDldeeSUPPPAA73znO5k/fz533333Ienv2972Nu/v+++/n4suugiAOXPmeNppJU477TTGjRuHJEnM\nmDHDI85ly5bVvNcXXniBUaNGceaZZwLwzne+k127drF3717a2tpYsmQJzz//PPF4nP/6r/+is7Pz\ngH2+8cYb+dKXvgTAhAkT6OjoYPfu3Wzfvp3BwUHOPfdcQJitb7311hG3vxKSyaR3P21tbbzwwgue\nteO0007zxuepp55iwYIFxGIx6uvreeihh0LWlkKhwPLly/nIRz4CwKRJk5gzZw5PPvkkra2tbN26\nlUcffZRiscg111zD2Wef/Yp9ixDhUCLyeUd400NRFC688EIefvhhWltbOeuss1DV2p/GzJkzmTt3\nLnfccQezZ88esc1EIsFVV13FVVddxdDQEIsWLeLrX/8648ePf80TfXNzs/f3Aw88wJ133kk+n8ey\nLEYqVdDQ0OD9rSgKpmkC8Mwzz3gEFcTQ0BBdXV0hC0Q8HmdgYICPfOQjWJbFV77yFXp6erjyyiv5\n1Kc+dcA+r1mzxtO2ZVmmt7cXy7IYHBwM9U1VVVRVHXH7K6Gpqcn72zRNfvjDH/L4449jmib5fJ7J\nkycDMDg4SGNjo3dsOp0OtTM8PIxt21x++eXetkKhwBlnnMFJJ53EjTfeyK9+9Ss+//nPM3/+fG66\n6aZQexEiHG5E5B0hAnDxxRfzve99j5aWlpo+2yCuvfZaLr30UsaPH19z/8DAAC+//LKntTY2NvK+\n972Pp59+mk2bNh0yLa27u5sbb7yRe++9lxkzZrBjxw4WLFhw0OcbhsGaNWtqCiGdnZ1MmTKF3//+\n9zXP/ehHP8pHP/pRtm/fzr/+678yZ86cA17r+uuv58Mf/jDvf//7kSTJG4OWlhYymQyWZSHLMrqu\n093dPeL28ePHI8uyJ3xks9kRr/nQQw/x+OOPc9ddd9Ha2spvf/tbHnjgAe+6g4OD3rF9fX0kk0nv\nd1tbG4qicN9991FXV1fVtrs6IJPJcMMNN3D77bdz7bXXHnAMIkQ4lIjM5hEiALNnz6anp4fNmzdz\n+umnH/DYzs5OrrzyyhHNuKVSiU9/+tM8/fTT3radO3eyatWqEaPKR4KqquRyuZoa9cDAAOl0milT\npmAYBvfccw8A+Xz+oNpevXo106ZNIx6PV13r5JNPpre3l1WrVgHQ1dXF9ddfj23bfPnLX+aZZ54B\nYOLEibS3tyNJ0gH72t/fz6xZs5AkiT/84Q8Ui0UKhQLHHHMMo0eP5pFHHgHgd7/7HV/+8pdH3A7Q\n0dHBhg0bALjvvvuQ5drTWH9/P+PGjaO1tZXBwUEefvhhb2zmz5/Pgw8+iKZpFAoFrrjiCjZt2hQa\n93PPPZff/OY3ABSLRb7whS+wb98+7rvvPn70ox8BwgoyZcqUgxrvCBEOJSLyjhABkCSJCy64gLe+\n9a0jkkEQ//Iv/4Ku6zX3jR07lp/85CdeVPiFF17Itddeyxe+8IVQBPrBYM6cOfT09HD22Wd72qaL\n6dOnc84557BgwQIuu+wy5s+fzymnnOKtPX8lLF26NOTvDl4rFovxwx/+kFtuuYWLLrqIT3ziEyxc\nuBBJkrj88sv53ve+50W4z549m3nz5h2wr1dffTWf+MQneNe73kWhUOCyyy7jS1/6El1dXfzgBz/g\ntttu48ILL+TPf/4zN998M5Ik1dwOwvJx88038+53v5tUKuUt8avEO9/5TjKZDBdccAGf+cxnuOaa\na9i/fz/f/OY3ufjiiznrrLO48MILec973sN73/teTj311ND5N998MytWrGDhwoW85z3vYcKECYwZ\nM4bzzjuPdevWceGFF3LRRRexZcsW/vmf//mgxjxChEMFKarnHSFChAgRIhxdiDTvCBEiRIgQ4ShD\nRN4RIkSIECHCUYaIvCNEiBAhQoSjDBF5R4gQIUKECEcZIvKOECFChAgRjjIcNUlaenuHD2l7LS1p\nBgcLh7TNNyOicXztiMbwtSMaw0ODaBxfOw71GHZ0NNTc/qbVvFVVeb278HeBaBxfO6IxfO2IxvDQ\nIBrH144jNYZvWvKOECFChAgRjlZE5B0hQoQIESIcZYjIO0KECBEiRDjKEJF3hAgRIkSIcJQhIu8I\nESJEiBDhKENE3hEiRIgQIcJRhoi8I0SIECFChKMMEXlHiBAhQoQIRxkOK3lv2rSJ888/n7vuuqtq\n39KlS3nve9/LZZddxo9+9KPD2Y0IESJEiBDh7wqHjbwLhQK33HIL8+bNq7n/a1/7Grfeeit33303\nzzzzDFu2bDlcXYkQIUKECBH+rnDYyDsej/Pzn/+czs7Oqn1dXV00NTUxZswYZFnm3HPPZdmyZYer\nKxEivGmhGxZL1+6jWDZe76542NuXZ822/te7G0cNXtjYy879wyxduw/Lsl/v7rxq9GWKrN8x8Hp3\nA4D9AwVWbekDoKyZPPdyN7Y98tjmSzovbOw54DFHGoetMImqqqhq7eZ7e3tpbW31fre2ttLV1XXA\n9lpa0oc8Z+xICd8j/G2IxvG143CN4d2PbOTXizdw3twc11x+6mG5xt+Kf/nm4wDc/+13oSiHTn/4\ne3wP9/Tm+NEf1ni/48k4F8075rBe83CNo/vcf3XzQpobEoflGn9rX+79+jv4+d0vsmzNPmRV4aK3\nTq55/I9/8SzPv9zNdVecytvnTHjF9o/Eu3jUVBU71JVuOjoaDnmlsjcjonF87TicY7hhu9BwN+wY\neMM9p737syTjh2YK+nt9D7dWaKobt/dz2tS2w3a9IzGOXXsz6K3pw3qNg0V3zzArN/YAsGnnAKcd\n117zuA3Oc3hh/X5mTWw+YJuHegzfUFXFOjs76evr8353d3fXNK9HiBDhtcE180lIr3NPqqEZ1uvd\nhTc8SroZ+m2aR/+YvZFcOJZtY5jiG1EPYAVqrheWgsHh8hHp18HgdSHv8ePHk8vl2L17N4Zh8MQT\nT3DmmWe+Hl2JEOHvGq6LTnrjcTdGRN6viHIFeRtHsc/bRb6kv95d8GBaticQqcrIH0mLY+bP5N44\n5H3YzOZr167lW9/6Fnv27EFVVRYvXsz8+fMZP348F1xwATfffDOf+cxnALj44ouZPLm2ryFChAiv\nHW9E8tYj8n5FaHp4jEzz6CfvQun11byDQWeWZeP+UuSRddn6VAyAzAE072x5iKZE4yHp48HgsJH3\nrFmz+NWvfjXi/rlz53LPPfccrstHiPCGwf6BAo3pOOmk+Nx6MkXSCdWbEGqhe6BAQzpGOukf0z1Y\noLk+QSJWHbiZzZUxLZvWxmRou+Wazd+A7H0kzOYDQyUUWaKp/rUHSFm2TVd3jgmj6pEliZ7BAk11\nCRLx8PMoayZ9QyXGtde9pusVSjq7e3OhbYPDJbJ5jaa6uLetN1MkGVdoSMcrm6BYNtiyJ8u49rqq\ndwOEANWXLTKmrbqvA0Ml4jGF/mzJu+dK2LZNV0+Ose11ntnZtm329OUZ116H5IxTXeBdz5cM9vbl\n6WxJeedYts2W3VniMZljRjfSkynSmI6FYiJ2dQ8zpq2OmFqbZGudUwslzbdmmJZV9bdl2WzenSGV\nUEknVOJxxfuOhgo6fZkidakYqYR/naV7V/B/G+7lIye8n4s7zjng9Q8VjpqAtQgRjkaUNZMbfvYs\njXVxvv+pswD499uWIQG3//v8mufohsXNd6xg9nHtfPQfZgLQny1x48+X8455k7jk7ClV51z7388A\n8D8VbbpKhvw6cLdpmWzL7mBq85SawsOR0Lw/++OlQPW4vBosfm4X9z6xlcvPO45Tj2vn33/6LBOm\nZ7A7N3LdnH+jOdEEwDf/70V2dg/z7Y/No7059aqvd/MdK+jLlkLbNuzKcO2tS0L3c8Mf7sYup/nF\nxy6vauOexzezZNdKmuoVvnvlZVX7n1i5h3v+spmvXHU64zvqve2mZXljB3DlBcdz3pzxVeev2TbA\n9+9dxZknjuaqd5wAwOMv7uH/Ht3E+88/jtOmdfLvP32W9iZfcFi5uZdfLd7I208dxwcvnAbAqi19\n3HqfiKq//v2z+c7dK5k+sZnPXSFWSGzcNci3fr2SudM7+fgls6r6kc2V+ffbljF1XBM3fHBOjdH0\nEdT8g0vvXFJ/YVMvP7l/rbddSuZomLUSuXkaVqaTz922jHHtddzy/73FO+avu5cAsKJ7JRefeGTI\nO0qPGiHCYYRmiAlhKK8BYDj+tQMZP4tlg7Juhvxre/pymJb9N/vcfBPhkWfvezf/ie+v/CkrulfW\n3K8bZs3tbyT8cevDfGHJLWimxsrNIsh21ZY+9vYXIFair/FZ+kuDdA3v8c7Z2S0ijQcOMrhJt2qb\nkSuJuxZ2De0hPmkDieNfrLm/qzdH4riXKI15ofY1MkVsoKsnrOFXmrbXjrAuf9veLADPrNnvbXtx\nUy8AK17uYSivIaWHyE+7H7mpx2lLRG4/8aI/Zv2Be3VzAGzYlfHb3LsRKV5gxYaemv3oHiwCsGVP\ntuZ+0zIxLfG+BX3uZoC8NSe+oC9bDJ0rN/Wjy3kxxoo4d09fPnRM2RDPOqkcuSVwEXlHiHAYURlf\ndDDapghSssnGtzGsiUm1NyMmt6DPc1XvWnoLB0524pL366F5P71HJF7al+/2tlkBf+PR4PN+ZOcT\nDGnD9BbD45zJlVEa/WVcRaOaaJUDBEC52DW8m2v+egNL9jxb+wDJIjZlNXJzd2iz+1yf6FpywPZ7\nC/6qHsuuHm83IK43EyasSvIeyVRdy6Li3rdhWpiWTWzsVtHGpA0j9jN4vf4KoaW/OMCSwu+Jz3hu\nxPMrz6nEf734E7723HeBcLR7Tisi1WWIH/8Cw8ZwVV8AJFXz/pbragsHZVMck1CqXReHCxF5R4hw\nGFG5tEc/iKU+Zd1Ead9LpvU5/mfdrwF/cnU1hf35Hn625k7+47n/8jQGoCoDl6d315hkNw5sYU3f\n+oO+l1eLxri/TjVI2G/0pWLuhAygW+EI6d5MESnuE0ZBD5MfgHIQEtOjO/8KwIPbH625X2ndj9q+\nl8TxYeuFu7xpW3YnALZR7QEtaQZFxSfvWn0cibzzAQKT6rIMpNbXJH+5xj2qTuCXYdqifcVpq0Yf\na12vp6Ivq3qFCVtOjEzQwf6XtDD5dg3vYcfQLnoKfWim7l9L0fn+y98mOfNZlOZetiUfreoLgBQr\nB/7WqIWyeeSj0CPyjvC64I2UZvBwwqwgU10/OPKWG4RWtze3D/AnJ3epUG9RTMq6pYcmm0rh4EBW\n8x++9DNuW/2/r9ifkfBiz2q+8dz3KRrVpBCc6IPEFyTv16p5W7bFrSt/7hFg9X4bpW0vcsv+mvtf\nCbuGdnt/l4zw5NybKSLFAuRtVCeRMi2bIW2Y32/+MzktX7U/eI2JDeOq+g4gN4nnbNvhB1jWTWzb\nJlN2TMuKUUWufZkScr2vKeZr9NGNZu/LhImx4JiWpbosyZnL2Bd/gd3De6vOryWfuEuuTMuirJtI\nqmjLNqsDNF3BsxAwZe9zTNKphAgEfMkhb9saWRjakFvtPefKe1m+z3cZ5PScZzavJGJNyZLT86G+\nACEhDTV8zu83/5lfrL0LzXnHa1lgDhci8n6TwPUvWrZdMYGaNY97pW2vBat61/LJJz7PtuyOmvst\n2+Kl3rUU9dIBr23VEAAOVV9/uf433LT0m6/6fLcfQfI2TCtEriMJMJpmIsUFIbYmW4Cg2VycP1jy\n/YHByaaSED2zeeC3bdvkdJ9MKv3ouiGIwbKtmtqWi9vX3sXu3F5W9673zgHxXILm/JAGG+hfppzh\nzvX30F3oxbQsBofL2LZd9QxFX6rHqrfYz4bBzdy/9aGq4yzbxjBMYsesIzaxtrm2ss18SQ+ZTHcN\n++RdOSn3ZkpIcX/cCjUEGNO0+c7z/81fup7imb3Lvevphskftz7MZ578Mn0lIaTJkh+xXiwb7O4f\nRG7sQ2l0yLscDnzTdJMhLYdhi/5KEhR1v49lzWT7/iGkhE/Y+RoCxLDdR+yYtfQMD4W254o6iZlL\nSc70a04MlAZDxzyzdzm7zZeRm3tInb6IPbl92LbtRZAXk7tZMfCMr3lXQB2zleuXfImCXgwJoK5F\nJp2IYds2e/Ou8CXh2pJ0w2JgqESuqDNYyrIz/gyJ414C/JgDwzK4e+PveaFntX9fWt5/xnJ1v7Kl\noWqzeby25m1ZNn/peoqVgfaPJHlH0eZvAnQPFvjCT5/l4jMm8fLOQbbvG+J//n0+Dy7bwX1PbuPm\nf57LxFEN/PWlPdy5aCPXv382MyYJ0vjNXzbzyIouvvWxeXS8hsjZIO7fIibbv3Y9w5SmY6r2L937\nHHdv/D2N+jF0r5zOj649J7QsA2BTV4Zv/t+LfPySWcydLrLzPfp8F3c/tpkbPjgHPdVNS6KZ0XWv\nLnPfc/tFAJBhGajy3/aZvLxzkO/cvZIPLZzGceP9VIr5khEycRumTUyt1iZKuomUEGSwfVeZNW39\nXhCNaxbvCfgyQ5p3gBwf2LaY4eQQ0IYkSZR1k49/90nOOGEU557lB9Zc96On+MAFM5h/6niG8hrX\n3LqEc04eiz3xRV7u38R/nHUjsQOMgaZb/P//+SRnnTiGf3nHDD73k6VkpN0kRCBxyKToE7PFnwZv\nB6A50cSG5Z1s2JUhEVcoayafuewUZk5u5cWe1WzdrPDw091V0dvd+XDw0n/d8xLb9g3xb+85ke/+\n5iWuuGgikmKCbGFaJorsE+SqLX384Herue59JzNrShvb9w1xyy+fB+CbH5tHZ3OKVV27vON/8dBq\nxqnT/WsnXkJp9f3QtUzSmfKgR3j3P72de34Dn79iNt/69UpSpz8ROvalbfvZ2TlMc32cz922DOnY\n5SSm+wKQVKHxLVu3n98/v5LkTH9btpynLp6mpBlc/+Ol5EsGiVk+mQxr1Zp3X/ol1NR+8oqBbvhR\n0kPlAnJdmNAHy2F/76833AdAfLLQqH+y5EFSPafQ0SKeUXncc7yUAzlZ+x5iEzZj2LBpcAu5cnXf\nDNPi9kWrKDrmckm2QBZC4pduX06PE6TWcEwXeJ+5ze0PvsykUQ3stzZXxRIM6znyJdFfqYZQ8Z3f\nPUeb7FtB5MZ+5PQw2BJIdugeilr1+bWEuMOFSPN+E2CjE7X50LM72b5PfJCWbXPfk9sAPzr0waXC\nf7Z07T7v3EdWiIIxm7p8Te+1wtXmRlp77EbuZhFmuv6hamn2ryvFMfc+4ZeSve+vIjBm+cbd/PdL\nv+CW5f/5mvsa1BoPFktWi34/tGxnyOddKOkhzbsye5aLkmYguf492eLhZ3d6y1hcTb7HMZvHlXhI\nU3DJ0bAMFu34CwPNKwDxvIediPdn13fTlfMjfVEML3p2l6O1PLVqN893v0TeKJAp1Q7ScbEvI/Yv\nWbMPy7YZGCqHTI1lI0jezrNP+pO1bulelHDZuc+la/ezYWAzt6+9i8cHBUm8vCus+e2vIO91OwYp\nlk1+/dgG1NHbWbRGmFslSZivg1i0XBDzn5buAMS6eq/d/gLL973Attxmb5tml0NLBOzOTeJ/UwgE\ntczmPaVe729TEtaRxc+NUIBJ0dm0O0NXTw7dsFCawgFykmqA5L87f1ixKqQVA2RLeedehCY7arRN\nLOW/vzm9uo+u8UFp2093xo84HypWa+lBzTtoNZJi4t4yWYOd3TnvGVYimbb454unM2/mKN4+2yfI\nn6/9FXvG/L7KvVHWTZZt3hHaJsU0hgs6PYNFL9+BlvYtJA2NYoy27xtClqvzIeS0PANDzvtYg7yF\nZu5bshLTxfcjmXHn+v67nCtWZ4orRWbzCIcSlZGiUrzAQCHrmbfcCdUNsKn000LtwJRXCzenkSzV\nfv08U63j56t15ZST8CS47MM1t9nqa5N+g6biVxOIogdyJQfHslAyQj5vbQTyHtZySJK7QNsMCS8e\neRd8Yqg1BtlymKzcyF8Xe3P+RCkpJprmulWcyzb4wlqmXE3ewTEKEpebgSpE3gEBSDMsUDXkQKT2\nYDHnBWC5SMQVT4iT04JUKpPT7A1EsQ+X/D70JlcTm7iR4lh/nfJAKSx8uglz+lIr+dnqX4aEqoHC\nEHe+fA92zH+PJMW3mtiYge2Oz7aW5q0NBI4TRDE40lI/xaA3U/RiG1yhIISA1hcb5QsBVkEEBA47\nhNubKSI39DM0cTGmFCDvcjUhm5Lfn64B35ozXK6+n6DmXWt5m6aJ96w0AnnbisbZJ43lX981k7NO\nGlO1X2kJC2NlKUd8imOSdueCWJk9TuKaGcc0ITf1hPz6l54nhILebBHd9L8LN2ZgWM95Yyyp1fcg\nqToZR8gNCktSsQXbkpDrhjyXVqYQHs/GeEOkeUc4tIhVJNxPnvIUNy3/ukfq7oTvEnStmsG1siu9\nWnjBOCO8fqZDDF6QTsW1C3qRVervUNp3UyxXTxSGUjs46GAR/AC1V6F5uzm7K8k7XzIOSvMeChCv\npBj0Z/0J1jQtbNv2JlLN1MgXq33Kg+UwWelGONZhMKhNK4bXF89H3uATT7YGebtL2AAKZoA43Ykx\noKGEzeYWyROXED/Gj3LPFKsrMCVisqfpuYFKlQUt9hd88t7W7U/87uQaxEDRHw/TMulveB65foBy\n0xZW9a1DM/yJfqhUYwJWDIYdTasQ8C3bloRqJ8g774zhPV+bQS2gPTvrg0dapy8pOn2ZkhfbYBuB\n4C6XuFS/j5Ll77cKIrmKaxbvzRRDhKZaooJXvkLz3jS4FSvhH7c34z/zXMDEbmbbkGyZTDDOooal\nwRVkhgvV38yY5Hh0S/e+p2DSFu/8eFhrjU3Y4AluFB33k6qxs1tsax6TJTFNuLdkW4yHkhTj25sp\nUTQDz0lLOPeV9yPTlWrNWVI133IQ0MylvTPBjCHFyyROehqAgYL/DZzVeiGtyRaKRumIBeNG5P1m\nQ0CajKsVmrdyZDXvkczmXiCRM2lVLrfaNbybItmQ9haEJudqbg9ix9CuUDRxELkAMb0as7k7gcdU\nudpsHiDQkTSUIT3Qf8UMBVaZlk1eL2AENJ+hgJbktl+pLZtWOFguUwoICLJBWQ8njwlOpBkt7PuE\nsHBQNH1hyU0sEgzyqQxYq4zyHa6hESZiCrudSHsMYbKsDCQKmnF39PmWCKxqrTWoea8f2EivuoHE\nCc+BLO47GImdr0HekmJ4VoVcYLy1DXNRiFN0iNHVzmPHrmJDYZV/vqPlDeVqvE+2BIpBT7ZAr5sg\nJBCZHdNF/Elw3EzE3+ZQC9aQKBE6rPmad5D8k5YgviB59xT6+MHKn3r3D9CdC0SmOwKKOdSKtuUU\nVCsVGsOagVmOcDGU10LzjLZsY2voAAAgAElEQVRtFi0J0Qc3ULJWamC5gryDEfbGUKM3Bjv2i/dR\nSfnHj5VEPIIuFVBkib5MMWzCdsZzqJwj4zyDWj5vggKSs9/oHYdeSHrjLzlj1p0V42V0T6TdmEZK\nTWLaZkjjP5yIyPvvEAOlQTQzaEoNkETg5Yx55C32u2ZzzSpXSemHMsmHa3IdSZu3bLe/Yn/l8ifX\nZOx+XJU+tqLtE9NI0dI/X/Mrfvly7dz6w4Go3FdjNvfIW5FCVox8yXCehY0yagd7ctVLbwAKhk/e\nUkVErGnZVcQ8FCC/2uQttO6gmX446ANWzCrNO6g51zKbBzX3vOFf39e8S9imgm0qlMxqn3cQtZaa\nxWOyPz6KDtgVS+L00Du6dzAgyFnV01qwv7XKoxYD1oOc5vfX1WpRDE+wcTXvuvxUrFwrshX3rDWu\nEKS2+W4J28bT8mwIERtA0m5Ckm36snl6B4uOUO2/N0nTIe9gwJfzHWtbT8Z2hJu8JvrQmy2FiClh\nC7N60KKU06sF3IGCP0Yll7wHRoMZQzHTDGnDXpayWm4C1zIwVNA9rdUcGIXZN576mMidviWzXRxb\n49sXAl8wsMA/xsoJ8pcSBU/zjiXEOOp7jmVS8ngAslqWtqYkvZli2IK2XUT2DRQDgmhgjLRts5x7\nEGOcSqj+flMV30dIKLTodSL0bSNGb6ZIWhWBevkaY3M4EEWbH4XY1T3ML/68nk9eeiKdLeGi9nm9\nwJeWfoPx9WP5wunX8H+PbOIvL/oaZnACiFVq3g5Db2m+l889bdO89VLv2B/9YS3nzxnPFRccf1B9\nfOCZ7by8c5Dr3z/b+1DveXwz2ZxGLqWBAiDxg3tXsdkpSPCpfzyJyWMasQhr3pWlI7tdf6/zcXUP\nFkKBa3nTn4SeXtPFI8v3ceOHTvMi1otGkUw5S4MzET350h4ef3EPLQ0Jpk9soXNyteb98PKdrHi5\nhxs/dNorWiFcYUOpMJs//uJu9vUXkFI54pM2cHfXBo4ZfQ0dHdNC5xfMvC9WKz7hqopMMbGXb6y4\nN3R8vkLzfu7lbh7ZvAncVNWySV+2xH/+5iVQyySmr6Bo+WSlxk1PAHIzuAXJ+4k1W2nMdHHh3An8\n/qmtDA6VGTfL13SDxPenZ3aI8+MlbC2JpOrsGxxC003iMSUsSOL4CZ3+j++oZ3dvDqV1L08Ul1G2\nXQ1JRBkPFzRu+eXzDAyVuPyiseLWJBnLtugeGgScSHTJH3NbjyPFNJ7esI1ZiX5mTWmrGVQUFCCW\nb9hLYgZItoK2+VSSJz8VIkPN0kgACScVpmzF0S2DsqlVuUJsSwJLDWt5FRpfPpNAaRVBcXv6ZEa1\npukPHJM0WxlmK/Gpqyi91Iytpfz2jJj4B+zoHeCWX66gN1MiMcrC/WpiOLWoy4M8u24/v35sM23j\nstDqdlJEUvfkMlzzvb/S2ZRkz2CWeDNgim9GNtPY2GTKQ7SlWnjmZT8S37+voNbqru0W53em2wG4\nc/09nNwxq3a8i2yKNtzgMMdaUVpzJraewDZU1NE76Fk3FmjwrmFmOmlKNEFJBLt2NI9l3fYBVm7d\nBzEorT4Lu1SHjEJ/UVhrmuvj5J0xHNf/DrYMlGHKWi/4rqkuTlmvuIf++fR0OMl0VIP+vAZpMUZ9\nmRId44UrIK8XSODniT9ciDTvoxD//fs17O7Nc/+S7VX7smUhDe52tJYgcYP/UUE1eXuk5Ex++/rD\n2vdjL4i2Hti6iJuWfjNkuq3EH57ezoZdmdBktvi5Lp5d3+1V7zEsg1Vb+ymUDTI5jXXbhfZkOaTq\nkXeFGd9dJuVOYNv2DbFuh29CzZm+VnnnY2vZ11/w2ga8NchlS5DDLxdtpKsnx+qt/fz2iS0hs7mr\nzdz7xFZ27B9mYFhM/I/tepIvLf1GTbO6YYj+xlQ51Hd3PIPEuLXGWveiJTRZ24gJE51kkUooJOMK\nQx3LveMkXZBVIUBGmmFy2x/XMRQ0dct+pLrascf3IzqIxy3vOXmacUwTpk5bwlSKvLBR+JT/vHQn\nz6zd7yWPUWWVklXh/5QNpJiOrSWxTRXd0ti0W5hcS3p4vNpTbRhogM2oVnE/8amrKdhhbV+KldnX\nX2D7viGyeY2N+4VmO8FJbtJfEM9/zrQOkulAycfhZmxLwlIK/OB3Ivgp6Au1ikIjdCO1g+M1Sj8Z\nu5wS5tsa5OvmsVY1oRWu69/gCUFuxrPyunnYhhr2V1eQt2vilhQD07LpaEqScuSQs8fOo8mY6B3b\n0C76Kak6tiWDrXiad1e+i+37hsgVdU8rbUk0M8Y6EXO4ha58F89u3k6uqLN70DeB1xtCECrbBbbu\nzrJsXbfXx6ljBOm675rrLtm63/fne/0P3CMO8br7zh53BhMbxmNjM1QerhKgZEMoIak5j3uBeZKq\nYVsSdrEejDj67uORZBu5LksqoVK2xHc0a2In582ayrFNk9kwuJlp08XzH8yL91wkh5EYk5jAgN6L\nlMgzujXtPceGRJpPvWc2tiV5yk1bU9IXnB3yTlvtTFBO8PqWLfrfaaGkM7vzJE5sn0FnXTtHAhF5\nH4UYcgJC6pPVfiP7gCUvCJnN3UxIru9VqTRlSbVNzot2Pk5faaAqorkWatbudYSDsiHuo9NZu+ul\nAK2INh9Z8xb34i6BclG2AxOxc7/BW3Ozk2mmVm1Wl02yAZPycCk8ybhc/IctDzJQGqyZdcowLZAN\nCsmuqr7Hj3+exPTnvd+1AuLKtiBDu5T2+pSIKSiyhGwEAn1KQrovB8hINyykZA65MbBGOKC9h9Jo\n6qItJRYgb9MCbKRYmeZEI7aeQIqXqtJn7sntJ6HEmdQwAc0uhd6Vjsni2tZwC5gKyCb5ongPKrN8\n1cXS4n1QjFCZSxcJSbwbUkwjmw8kRTHFxDyrbTqqpNDPDuIxiX+7ZBbHT/K1HqtUJywA8ZJnBXH9\ntbGhiRyfOA2A/ny1sCNZCiCBWaE5O/uTagJFlogPC3J9dt/zvrAqmyT0Nuxio/C31iB/M9sqtErX\nv+1sb29OolllpjQdw+XT30NCrqO8+RQAzjhFVC5D1T2N2y6lMTPtKI0DyM3i25Ad8rzm1I+RUJKY\nvULI2Ws6Firn2zH6RzPJnOe0GXgXnb5ceubxpBMqaOJdcZMDlZx3Ttt6EqUX345VrAuR97hRTh4B\nh/iS8RjHNYtqeHkjXxWVPaV9lD+8DQNCaFUMMGOeddF2+iCpGsm44rXxrxefQjKhcsGkc0UDDb1M\n7KzHkp3+OO/8xLiwcClt+xjVmvb6m5QTzD6uQ5i9nW1j2tJV1gNVkZkxfpTTB92L1bDNGLppM731\nOD520j8TV0Yu9XsoEZH3UQg3pWFDuka6wVcIsAp+YO5yD9eXqCgyECAb+cDZygyrOjBDMzVRLEEO\ntx3KmuWQd8kh77FO3WOXICqXihkBn7duGV6gkrf8JlS9yaZsB5f4iD4G/es9gexfwSUvUqJA6rRH\nWbTjL962XLmCvKtyh9fI8mZaxI9byZ66p9haeDm0T2nuC/2u5VPXKGDbYDlZtSTFEOStSEiaX3fZ\ndLRGzfKfeX+5j8TMZUiq4S83CvrNAwFKdllMikrM94drugmqjiTbNMQbsEsppHiJTD6Q7U6y6Cn2\nMrZuNC1JQSZBa4LWuB3bkjB6JmBbigjGGhSknXeimG1TYczwOYK8Ee9lMmVWvXMtsrOkKFYmGxDS\nipYg77H1Yzix/QSM2BAtHWUkSfKIxb1HW0tCrOwJGG4msobisXTUif5nQwF8jvbs+DhtUw2Rr/ve\nJeQE8ZiCVaynM93OjuwuQd6ShSTbWIZ/vhCgrND5Vq4Fu9jgkYvSIN7rlqYYNjZJ1dHsFckjLtcq\nIyl6IChNwth/DBAonOFcI6UmUWQJa0jYyPP0e+MNYPaOp0FtAFsKPUM1bnrnx2Iylkve5QyWbVGS\nRTu2HgdkQdJObAKAGjP8sUMQn/usl+xZ7uVkd3FSu59tJnHcS6RmPYuk6khmjNaGROBaQEwnHlMo\nODEPKUX0rTMlNN5sOYuqytiyLtwWtqC5MbFjkWwFdew28vWbkRQD25KIqWIcG2PNSIkCclOvqG8e\n8Hm7z8H1a8tNfdhjnRUThloVVHskEJH3UYxaUeFBM26tJQth8hbHFsoOwcmSZ+6CEaIxAyjVIJ4l\ne5fzu81/Iu6UKHQTHlQm+we8oLr6VIzm+rgXqeznwq4OWOst9PmEqRiA7ZVelOoyJE56GjsogDj3\nIAX81K7mDWHylOurE9G4ZnMXlZp0Lf+pYfpJNoaMAye3qWV216WiiLB2J+eA5m1LgQxteYe8bf8e\n9pZ3ICkmetfxGN2TgLDmHXz+linGRFZNz/pSMjSSJz0FQL1aj1VOIUli+ZUXSZ7MY9kWY+pGe+lb\ng9HphprHLtWBkRBaqwQ9Q4Jsi6azpGr/MSQK47wJXU6UeEL/XxLTnwtZB1plx7edyntCK0DRsa40\nJxqZ2ijiMFIteeceAuRddDRvyRcw3ICidCzFqAZB3iUr8Jwd8rZ0550xYk6kcfC9EwlyknGFsm7S\nnmwjbxTIlYre+YYhuwMi/ne/rQpSsDVBCLGJGyFWoqlRXNclJUWWQBcEVrByoh+q7hEj+OlTvWVy\nAdO+KsvYWgoZGSvmm91BmHx1A5JyGjk9jNK2l8TMZ5AdQSKlJokpMmbJ1byz/HHrwxjNwuftWg2E\ni8dGaXfW5scCPnkH7rNetm8Fd738W4KYN2Yu7516iX8/ySFQdFRJCEiiLdcXrpGIyRSMIkkl4WXO\na3LqqWfKQyKHhaO5e5kiTJVYYTSSbLFOe1ospzNVYoo4/8KxFyFJoI7ewdi2tDf/ueMcU2XqnMC7\n2Litfl/N2EEVHDrUiMj7KEN/IUPylCeQW/bXXCccJCOj1gtVg7xdYrVtO+QTr6V5u/5qqC7WAHjL\nJJTGAZANr22fvG1PA3LN5om4Qkdziv6hEoZp+ZHyjoYeLIPZEyBeSRZtuZp3/NhVyE7mLjcgxtMw\nAm0EyTtoqQhOhi4KevgeNcMK+fprJWUIErxsB9s8sLAFIg7AUHJYpTS25ZyrGCTiCrIsYztadHnT\nqZ42plt+H12t08o3CpO1c76LEHk7yT0kxcS0bAzTIqsNeoFCti152rmUKLK/wmffnGyixZkw3XSu\nSBaWpHt+WDdCtzebc8bL0byNGLph0eFoS7HxIpuZXJ8N9bfDnoptg9wUWAoGDNvid0uiBVsT10qk\nxHlFowSWjLZtltBuXXOrI2C4qTjr4knGNAvhQx29E6VjlzceAGXNyXtQSiPJlne+uz+pCGIp6xYt\nSeH3HigP+uTtnO8SnEsGleZYs38Mck6YY+Vkgfp6cZ6vecvYegJsWJ9ZizJqJ5IEiuW7GWwthW0H\nnoOsE1fiKLLiLAGVSMuNSIkCx09o8oPLzBjDeY3ZTWcgqQbxY1cj1w1jJ7NOH5IidqMk+rJjaCeP\n7XrSfxCGK4CIMY5PWRsao+A35Uac10JKTXJK58zQNkm2ScpJLzmPa2lQO/YwNGoJBb1ISvXT5SbV\nBEklSaacJaZIQrM2VRJxcb5hWpT2jQ9dwzZjnvvw2JaJIj4hVqalIeG9h+51Fdm3HoTa0OO159rD\njIi8jzI8vnOZSBRw3Es10xAGyeBHf1hbtT84eWu2+LusmTy8fKfIchUMOlGq28+X/PaD5kkXwQpS\nUsqv4OOlHJRsQbrgVeJJxAR52zbc9sd1dGcdE6ZD8oWSwc/+tI7t+4ZYvlVIvHaAmAbcDGSBicI1\nobkfYFk3uOuRDdz57OPszPjpX3/1WKAkZg0ff7DYA8AdD7/MMxt2+PsDWt7zG3q4/+ltXoY1gBUv\nB7JG1bBkDBULWJbN/zz0Mt/9zUrW7u0CyRZBOs49SorQvFVZwpZ1UnIdVqbTm1SCfmQ3iMc2Yz75\nB4Uwl0D2noDpaeZim6abFAMa6JT6qb5Glyiw28ls5b5DxbzM48sdM6yreTt+U9fE6fZxy0AXS9fu\n8zVcI0ZXT47lS2JYxXqkej/gUJIgZTfz2TmfADOOlWsRVhGnbbmhnyFpPzNaj6cp0YBWEtey4jl+\nt/lPDJYzKHoDZt94QAqR992PbWbzfiG8NSTqGN/a4l03Ptl5FxzXgusxsYvChy4lnVgKR4BKqkkS\nMZmybnoWiKyWDWhsYdLxfMoBUhjVmgYkFIe8pXiJtMNHSVX0W5Yc06/zPONOXeykEqg1YMvCwuCQ\ntyXrYc0dyA6oSDGd3WN+i5zyg62GCjpnj52HbVbTQVJJEFcV8jlQibNreE9ovyuY6Hum+hslv5JY\ncL16LeKb0Xo8lxx7MZIk0ZRoqNqfjqX9zHqm6iXsKSf3M1jOkI6F6y00JxrZm9/PQMOLYi4zVS+W\n4p7Ht1AeaGFy9p3e8dZwi5dpsi6pYhtxJFVHVWTf8uhp3lLoHqxcE+WNp4KexDBtNnVl+PkD6w95\nIaeREJH3UQYvwtGSR9C8fXJdvbWvan9wOYcZINpHV3RhmFY4KrZm1R0/GKyW5h3URCXVr9Lk19AN\nrNX1yFtm1mThk3txU6+fwcohnadW7eXZ9d3c8svneXHnDsDRLBHLSTzLgeFrIp3pDq8PAC9s7OXJ\nrmdZXlgUShm5bmcgM1egb66/2fXLu9jTm+dXj/vJN4JLjH58/1r+9MyOUCYwSxJ/nzdnfDga10Hf\ncJ6+TJElq/exbscgK7YJ4cQq1vtai2z4ZnPZICY596m7+Zb9PnqWF1MNkH+15l3YO9Zr31svr1ue\nyVnfNY2J6SkhzXt3r/PsHRLq2qvRtdsJ7nK1UvcenWdh9ApNRx23lb+8sMeL5G9K1lPWTdZvz2L2\nj64al3pzFJObJqEbFtZwC5IEckoIdW5ynrdPOFvcS05MY7uNjSLeApDxScMn7yKPPt/FcKmIbcP0\n8e00pfzJ2DYVLjhtgvfeubGK9YrQqpV0nqb6uDdeKVVohZpmehaIgVLGF5Zcn7kTeCincsSOXYXa\nIVZttNbVMXl0g9NH8b6NHS15edBd8vVQIfx1NjaGfttaCjlRIjZlFSYaKZf8HfK2HdO7jS0sHDbU\nxVNccf5xjG2vJ2ZWL29SZMVZlSJR6vOjqMubT0HvOs57zu3pZibEnWWkqk5RFXPP5I4OLxVqXQ3N\n+x+mLOSCSW8T/ayxfGx0U6OnOYPvv3aRVivJ2zGdpzYiyRa2odJQkRBmzqQp3t9mpsPLQJlMqCTk\nJHJcFwKPa4Gq4bcHMHomYGU7aUzHMEyLp1ftZdm6/fRnj0x+84i8jzJ4ZGHEvIQQQYQCoCo0ye99\n8kzGjvJfZMM2mDymkUmjGiiUDQzDqliPWi0cZEv+MqNaPu9g6khJMTzy9uoDy7XIW+GMmaOZM80h\nXOe6biajYPUeOT0szLmOyTc4oU0d3eH9PaqCvLsHi6GsX36DZu2/HXPgLv1lnt+/MnxOILCndi7j\ngHnc6d+0Cc186rJpVUdqZjkkhA3oTgnIYr0n8UuqCNBRFAkUHQXXzxj0AYoJzrUE2IbqTToN9YHP\nXNHF0idLEe3bEpYsyLikGb7mbsQo66YnxMjJgle8xBUWTC0WIka3L+75AP/90X9gUuMElPpBhu0e\n9sbEWH78nbO5/v2zAbDyTVXjIluOS8AwPdLxVg441291TNXZrF1V79pSfWuEbz1wBQwDhRjzZo5B\nkiTU3aeKZUKKiT12LRPGiOuVyqKm9LX/cBYA8+c1M3/2uEAwWIJETMEGGmMueQ/6AW+u5u1o7kpr\nN2rbPuQ6IYTc8E9v83IPWI5PefrxKe+7cjVvL6eMHo7GnzVugvf3tz8+jwmdghzV9n0YUtkjb1fz\ndp+Vi3Qsxa1Xn8Ox45qIqTIzx4r2bEOl05jBW8ecDvhLSs0+EX9QrzZgDY5G6fdzPnz742+lNS2+\nydiY7QzYuzm+ZSpffN+5/MvFM4BqzbshVs/ExrAZ28qFBZJxLS2hnPZSxZyUrqHNB2FrqVBFwpnH\ntPD2U8czPjlZXG+oDdW5P1mSOG5MBzYW31/zQz/4L0DeDTFfwDH7x5KMKzTVJzBMS9R4l6Ct6dBU\nX3wlROR9lMElC3dyrUTIh1rxosdUORSJbdg6MVUmnVTRdKeggHJgzXsoQN7lGpp3KFtWgLx9zdvv\nk0fejmTtfaTudR3ydgPzpGQOuW4IK9vmE1egv2ogAdKYOmfpiUMm/dlSyKzuJVEIEHZQcLED2ZTu\nWH936B6DwVmVZnWlfTdKu798zG1TUYTJOwjbktCscMrUYVMEuNmlOo+0pFiZZFxBloVA4+ZxxlKw\nLRkppnnpJstWwIXg3INhB56pE8ErGEEiIdWhSeKZarqFZrmFMWLifdATyCjIyYDP2yFRvaSCGcM2\nFc9c6xKrazaPqTJtyRaQID9KVGhS+o9lcvNEjxRcK0pojB0TsW5YHmkpDYNI8aInILgTaX8mXPEL\nwFQC5K255F1EqsuKwCzbJ8J0cRJG9zEAPLN/GT3SJm98VUVmVLodWZJZ2buGPnmrFxOQiiW9dzet\nOMVB9Kz/jjvjbzlL+jwyQBBZc6LJS0lslcWzzpQy3mqKSh+xvP2tnNJxovd7Zsdx3t8xVWFWy6zQ\n8UmPvMU4G/smc0LsLG99eqXw3ZRwnoNkc4w1jytnvFcc5xatGWrj7JaFvHvM+wHoqCCphrgjPIze\nSVxOcOnUd4b2B8n7golv4wunX0MlyhtOp7R2nve7M91BIjYyTbkCigvTDs95drE+RN5u8NtFoy6l\n+OLbwYyhBoJZ61RxDz3FXuRkwUmyI85RFZn6eB0fPuFyJmbeAbZMQzqGqsjohk1vtkRrQ7KqENTh\nQkTebzDolsEL3atGXPJVMv3JtbbPO1A4voJ8Y6rirSEGQLZEBKVTYSlb0MKm3Rqad1+gwENNzTto\nNlcM8k4ke9Eh76CJ17CdJTfOB5WIK0h1Wc8n7loO3OVZSpvwVZt943yTslJtogY/o5N7P2U9LJgk\nqHP6GPQHB9ZDl/2JqXISDd5DZWrP+JS1xKesCbTpkLcsoVNhTrNUdFsLZR3TLFdzjnvFFKR4mURM\n8QOeLH+JkK3HQfXJW7c1Z3mM4pnNTcIJQkLEJTWiSQWQRIpUL3LdUB3BSyItNSKlh4jPWYSUzHkC\nUbkozKlWoQEplUNp24OccM+Pe/fdFHdIIZHHzHQwpnwasiT7BXMMv7b4jBZhnVBNJ5LesDxBQB29\nk+QpTwrLhy15/s7eTMl7Z+aOOhWAMaXT/HE2Y9iGitLc65fRDFilEjEFK+dr/xa+5qwqMnElzsJj\nzmNYy/F88REUZy11OuYHU6WoR5UUitKQJxDalkIqoYAR9zK9uVAlX5sDMHSFhBJnsJxlq5NCdErT\nJIKQyg28a8qF3u+JjX5lrrgqc+boed56cPDJ0hUQsFSmpWbzDqeNyY1+8hcARXIEViksCfnr6yVa\n9anETfE8O5rDxOmSN8C5o89hQsPY0H41UBP+H45d6AsLQVgqdqGJy6dcwTWzP8bcUbOrqskFMXfU\n7NDv9x1/ifftg1jn71aQA0g6wlZSjXvvnRog2/p4WJMXFj4xfm5g2+mjT0XVxfuSTsaIKRKGKQJn\nK8fkcCIi7zcYFu94nP9Z93/cv/Xhmvu9IDFbqql5B3OaV5KvLNuUrTC5xlWZdDIG2BhNO/xoVaiK\nNpcb+1jc/cfqvgQQIrMKn7c6ertXHxfAQiz18j5OtRSuUSyHydsNGDKHW3yTskNoqiJj4pN3a7KV\nhJIIpYN1NSZzqIVOyzFhBwQc19xp5ZrQd87wtlea+4Jthgs01Fiap7h542VPsDGHm/nQ8R/ANhVM\nWw/lHNesssiFbckhzTsekz1BwF0/DIARR1I16p01/yaaFyTkBqwFyRtVJyb5ZFkvi0lIbhhk7eAa\nj7xtM0beqaLVoDoR5bKN2tnlkVCx4JiF841IEsSPXYM6QQRTeZYRSQpN0ma2jQ4nKU88oFG17lnI\nl97yWT4y/QOUN84hVRJJRXTDCsUyiPHQUOwEsiRjWlaoZOqxzZP477d/i3GcGDonKIyBsxzPQTyu\nYA2OQt80h45Um3+Q5QsYFx9zPh+Y8T5/V66JZEz1a0obNu2pdqz4MI0Nzn2ZKumEeBZuJjcXHzxB\ntOUSgmHatCRb6C8OsGlwK82JJi8ILvhadaTamdQwgYXHnIcs++MXU2WSCVUkxnFwaufJ4hoBzTKm\nysyfcDafnfMJPjDjn0J9mtYqgs7M3rApOzPsv++9mZIXhNpWURmsMemblMc3d3IgjFQO2MVJ7TM4\nrmUKkiQRj/vvu7Z9JlY5yWzlHXz5LZ9leutxofPG1o/m+jmf8n7bxbqQ5u0+r2CKYzVQddHVvF2Y\nw63e30GN2nUDphNqiPzbm4+MyRwi8j6ieOz5Lrp6wqkphwoaDyz1g5y2O8kLdgyJZSuPrOjinsc3\ne8uhvCQNstCUlq3bz12PbOS3T2xhqKCFfd4V5Js3CuGkIgHNW27qJT55HWpnIA96KEDGJhYo4wjV\nAWuWbYfK5EmKEYo2V8eJ7E5mthUz60ySkuV9nHvl1aH2RE1rvw61p7kYcZ/YHD92XVL1lr5JG9/G\njq4yacXPmAR45KdvO9ExHVdq3k7U9daTwIxTeukcZDPBkBZ+ZkENasPgZm596o/c+9ctxBM1stsF\nNG+3kIax5zjmjj0RLAXN0vnLCr82s47uCCaSuE9bglhZWCVc4UMPrO/V40iK5UUoIxu+VcJdR+ya\n6yUTSbZIyP6k26AKv3HsmHX8pe8BhmNOX4yY9+wanWPASTiiashWjGLJoqUhUdNnbet+bEWQvO1S\n2iPvYKnaJI2MruskEVOxsh1YTlSxHtC8XcjJAoolnv/9T2/HtGxithCwOlLtSJJUpa25goxVrMPM\ntnGccoa3TxwrYWY7mBYkA0vxtFZJkpg35jRa4+K9NbonElMV7zovbupF1uqRFJN0Y9k737VqeX57\nQNtyMjNahb/YNWnbNqT0jnwAACAASURBVExrOZaSWaZgFJnaPLlm8Q5FVvjc3E/xrikLKrZLwrxs\nJLC1BAoqJ7YLAVRRwiQPMLlpkhfU6eLE9hOIbTsXfdf00PZgVbvebJGCM1e11CdCxzUHyLsj3Uot\n/MeZX+Rrb72h5r4g4oHnlwz8bfZOoLzqbYxPTmZUXW0BIRiBbpdTNc3mSlCgCZJ3haBuF/0I+CDJ\nu0pJXSoW2t4RkfffH/b05vj1Y5u56X+eC23/5cMb+MNT2/ijk6fc9dmokkJfpshv/rKZxc+JZTa6\nqfumV4e871y0kcdf3MOi5bt4fkNPyOddGdzh1om2yk6QkWx6Pu9ahelD/uBE0VtD7aLSbL5++wAl\no+RPtorOcN5J0lLWQRITsbb5VH8Nsmx6H+cwvdiWRGnVOZhZ5+OXA+StaiKBhy37EbxOn+rTMXRL\nx9bjFLJJfvC71aTUdCi5hr+EJ4ahy1X36Js7nWIMRh2y1iisCZIFWMSnP4fS0uMtWQHYYDzDw8/u\n8pa+ubBtKeTzdjNC2UYMWZbEGnDZ4PHnffK2EMk3Jo6qByTQ48LnHTCbG5r/2bpaaSJtihSwiu6T\ntvMcOjqcNe+uoBPQLprjzc44OlYBxe+jm9K0PuaTr6Tqjt88TqFk0NaYrE3eAW3ZM5sjfPluLedY\nYFJWHRLz/LPOulndsJDNMEkAyGaSkmbw4DIh7F7UcQVXTv8nprUI7TFoKhX37rgjSmm0jXOZnvTN\n6u77Z9swJu2n6cSWQxM7wPunXIm2YwZm/1hiqkzKuc79T29n5y4np3Z6nTjdUkgnVc49ZayXZAXC\na59PmSpMvP9w5jGcMdrv07njzwx0vur2PbQ42cckSfJIpLT2TK4c93FPuw1mF4yrI5ugAS47ay7Y\nMmec4I/DhXP9wLjeTNEjrvGdgqynTxTvUFOAvNuStcm7OdHkrYmvhfEd4t0MCl+1zOYH8oMDTJFP\nQ983GZBFeteKtoKat+dWIOxDNzPtmAP+OAQ17wWnC5fD204ZGyJvNxvckUBUVewIoViuvfZv/4CY\nLF3Tn0veiqzSE8gnPVzQ6Q/UL0a2KGtmyHReLBuUpaDmHSbkYUeDtMtpSJRANompCnXJcO7lifHj\n2aVtCpO/G6S07xhGG7Pon/DnqoC1fFlDUkyxbjemISkG/UMlLMumZJSRZJvpbZP5+GfO56tPbKef\nHpAt74PSEZnF7HLaXx8qW77ZPKb564fLKSRkSAhLREdTim5LCwWaqXZSRKzLplgj6yWmUCmXbVER\nyCHsKWMbSU2qZ1sOT7CoS6pCY0oBqoYkWSL5DAifbkX0uhtjEJfjDL94BvETlgc0b9kb/1s+IqKX\nZTuGpYhc4u4MLSkGthHnuHHNfPby2Vz/2FKkZE5MHE77Wjkwm7sJJGI6LQ0xioqF5Wb0slQScoJU\nyuCmj8zllj8+AAjTopsfqjXRChUp6t1o9JyjeadUn4ileAlUDb2QwrJt0kmVn3ziHfz48TgbSi+i\nNGRoUBspBsgqpHlrqZqatxfxK0tIkh+kqBsWsRqEI5kJT7iYPKaBK+efSl+fbyFprwim0ndNI3Hc\nS+h7BbknAqbYoJY3OqTNSSGTKMC4pk7MHuGLjqkyHQHTsV1hGscU39aHFkyjdetuFu0Sgkaw1vak\n0Q386NpzhGUFeNv4M2lPtYX93QcoV/Ctj83zAh49Td2Ih4g0SE6V91OJd501hVMmt4a01ffNn8q7\nz5rMt3+9kr39ec+d0tqQ4NZrziYVF8cGY0Nqrek+GHz5I3PRdCtErkGzubftAH5wgOPUuazrEgpR\nkLxdn/dImncwT4W2KRA3QVggPPeUsZw+YxTppMpTq/wA1WT8yFFqpHkfIVg1UpUCVaYxw3KLhMih\nYhCFkkFf0c/JLSt+SktX8ivr1oE1b6fghuf/U0zH5+1XParPT+WczvnO/mCqVD/pQn3MCc6p8Hkv\nyQo/va0lhd9WFVWSBoZLXtnIpkQ9qiLTXu8kvohp3sdZtovexGa7NZklV/O2QdWwveAmmTq5Ednx\ng7c3J0WQn+l/1JYernYkMi4JE2nZ4V3h47dJxhUM1zfsCACphIrlraUuh0zwthFnsiYKIXjJLZzx\nPnPs6SiWWKftEroiS2TKQ0hIjKpvce6gVhIVE0yVeFymPhUThUEUC1s2sJVgoJjTDyeoTZcLtLW4\n5nKfHBpiDWS1Idqakshp8fyntvqaVGstDckQZnt3kp7deiqzW+eIcUgNI8m2l9WsLqkSjym0SZPR\nd8yEgQl8aMpVBNXFUGCSLdf0eYeIXJFFwiBElbRa0buSmfDM+lPGNFV9R5WBQ9bgaN7ffg22YyUI\nam5BIvdWKQT6EkRdYAKPq3LIxxn0N4MTsJZUkSSJ9nTAOmGE1x2nEiqyJCFJEv90/Lt5+4Szqu53\nJKiKHCLaWvcUJKr4K5C3JElV7cnOto7mJLphsddZMphOxqhLxjyiDa7jrmXyPxioilxlNamteR+Y\nvINCyiuZzYPHnth+AnVqmium/2NVm3WBQlCSJHn9DL67ifiRo9SIvI8QauUZr7Xd07wlhd6MT475\nUrXm7aIuJV4iTTdDRSrCmrfN5sw28Zdjcg6bzZ3lN6UpXuIKKWg293Ihq6QSKkk1GdK8dctgW2GD\nc7AFZswj/L5Myat85Urnk5pEQJJclyURU9AtQ0RKuyZ3JxmDu9YbVRdm4YD/s0FpEfV3FZ2OppQg\n74DmrZXcQLhAZivHZFly5CK1bT9Kx26x3MPSPHJXFYlEXMEoOwJATAv5usHGGhjDhORkp9604Res\nUBNiQgsUtFBkiaw2RH28zsvFbLqme1dIkiyRWMJUfFOuQ85lCuiKIF+9GCAMxyeXNftobnL8pwGz\nbGO8kbxeYG+xy8vHPGOUr9U1JeqIyeIeZdstpOFkbnPMo+lEgg/NfC9WOemZ122nIlk66QpbNnax\nAXXvKbSkwlHESSXBuPix6LunoioyTfV+JLoLNaAdKrLkFXoQmncN8jYSnvm2crKH2r7HukCyjrBZ\n1m+/MR7O8hXsl/gd9h8Hr2NraYrPBXzRAZ93Q9zXhG0zTN6viFfBg0HNVKkIbHu1cO91pxO3U1dJ\nskqck9pnVvnjXytqEXWlUHWg/alEtQl+pIC1hng93z7nZs4c+5aqNmu9ZxAm/1cSKg4lIvI+QhiB\nu6vglsNUZYW+rK9590gb+O2m+/0DA8TqlgYtaWaIUIOat9zcy7J9ItLbKjnLpGJlYorsmM2dNddy\niua0Y/JSqoO9MEVO6sZ4AwOlQTQnA1le9zOvGb0TsE0VJSau35sp8v/Yu/P4qMqzf/yfs81MJpls\nkAAJ+yabICgo4i5Qt69WWxUXcKlaRVu1daFUpbUPuFT9Wbva1trqQ12hllddeLpp1YLWlcUVtAjI\nkkD2zHaW3x9nmXMmM5mQZCYZ5vP+h8xkZnLmJMx1rvu+7uuOWevL7eA9uXqMeVyl+/CrDx7Gqzv+\nbZ6npJ7Y/knrzKBmNUZxFy/ZhVRCoB0DyvxQDc0zbG533rKDrjkkbZ6rcLsrWJTvRWNwM3a173Iy\nd0kU4VckaBFX8PZUrsdR3xSBz96yUo45vxO/5IMkCOb6Z8mcKxdFc7cjuwMUAGiqfYFiPq+4OLGk\nx+nnbL3fiNaGmGiNnESCieYeVrOaFr0egZB1fK7gXR4wA+nKj58xHx8NYGBxYs4x4JfNoXMAJdoQ\nCLGg01TEzmz9PslsEqO5A5V5UWF/gNt75IiC0GGeWBAEnFJ9DtQvx6KqPODMwbqzMzkp8/YOm4uY\nV3YhYlumIbLpaMS3j4PSNNK5uEgOIgBQXtJx7tGdOaWbUxUEAQtGXYzohzM7HFcyRRZR2mFnP8HZ\nIcuI+Z2LG89FQYoe+r3NfUHiHjbvjeAdjWnmErqkQCUIAr459RKcMvLkbv+MVFLOb2e4oHFfdLmH\nsv0phs2TL9DSCaYY4QAS9RoAg/dByT1s/t6n9dANsxfuftd2llu/bEJMtTJcw0BdY9jJAPeXJ/aA\n1ttLrAIq8zXNDy8De/WtqI/sT/xQV/C1h5dlQYbeWG0uMSpuhqJ4h819QgClwSLo4SDEUKPzGu7M\ne2d9G0q1oYjpcTyx/jWs/2C3s7etumeY9fqJIri6pjBiVqFdsbWOcnRlrbmOdsBufNGyHau2/MU8\nUCt428PmghL3NOZwF0KVyFaLVCWC0lJ7eU7iP09Ts13oZm1VKCUqsdvaBGCH+SErVdShrug980nW\n/2NZMiuW7QsdqWq7N/OW4tjXHIEMa3jWmuMHzEzTzLytD3YljrgRQVxXUe4aQraXfCnDzKYgpSF7\nIwvZmUqwq5TbjVZExCYYutnD2mn5GPfDiPuwH19ik/oP8xQ0JqqI7S077SmX2JbEOmDAXPs/sMgM\n3pE2H9o3HO08xh42d9bhC4lhUfu47OBk/32bc9YdPwztAJuuGtcdJCVR8BSs+WQRgwODoe0fAqO9\nFOquMdBVn7MbXjDFvvbuzMo5Blfm7Q48/qQ51YmV46C3mFXlyRciboospXyvV0y5GPH35gGaL2Xm\nndziMxv8nmJAd/DufnAZ6JqKSHXBlC2ZsuxMz3FPz9gXAuky784Up/g7AwBZTrwWg/dByJ15P7Rq\nA155dyfuXvmOp9HK8sfeRn2zOTcc0+PY1xxFZWkAJQEFQsRd9GP9J7IztiIFUtUObCsyd/s5acDp\nAOBtB2oF4UvGLwIMEXpbGUR/GIZoNfiQzbaZPsmHoF+GVl8LQdQhVe72PB+agoaWKN57y/xD/vN7\nr+PXaz7Axm3m4+zgamgydMEMmvubo4gLVvC2Mm9REJ0Mz3OeUvTrdg9ZuzPvErnEeZ8lQSvwuTJv\nLe7q2Caaeyy7s57wl0M9VePun2suvZGgt1ZgQukkSKFGSFWJZXSxz6bCMIC4nZl7Mm+/uYey9f7E\nYDPaNHsLy0TWO7bYWspTuQcQVZSUJC5A7A+BgUHz8a/sfx5hcb815SGgusIOgmaTFA0xRIw2xHeM\nxfTBiTXqdvAGzPXtRpv5evaccNAvozpoBqq2Fsks7LOCS0u7+Tu3i3wG+BJroO3gbVfXjq01f85h\nYwc6owLuoFhZGoAAYGhVx9854B16lCUhkXlrZuadnDFqmp5YrpNuODPpQzlV4RLQ8QPXPUfaWYGX\nnbFVliay/AGl5haVMsy/02CK4J343XXN6CHm//3Dxg3M8MhERul+T6LYu5k3kH4IORvsn+WTRQyz\nKtyryjpvhuK+6PLMSSuJkbVU3+/KcSSTPXPeuQverDbPESczKWmAPPQTfLjTu7ymTdgLsbzOmeON\najFEYqq5jlY30KaJEACE6meiWbaWFllV1MGA7GzacP74ryKyx9zooXaIjMtPnAXdMPBKXQPW7/3M\nyXy11lKIZXUIi/sQ9I81s1PV3NtWFAV8e958/PKjTzF1qoh3/55ocFLiC6IZifWPdr/o3U1WW0+7\nGMfOOiUVMVWHjihEeCtSJ9YOxsdNiZaR5vPNDz17pACAOd/tt+daXQ1GrOB95IwifN76mfVzXWug\nnUYuGgTr/BiajOKAbA25CmZBmL9jsxnJNSw4o2w2Pmr+wNmJKbLhWHO/agDtLQJQYgV915y3JEWc\nJVRicTNaVfPnuzPvb596PJa/vBX75E8hKDFUlgWxwzpGe8574UmH4uebXI1rrO5XQwYU47yTxiIU\n9OGdLwdgXcM/UOoL4ZRDvo5h1SFcEDYb5OwT/us89fCRI3HWCWbryTsunYn9zebWh9VNZlCwh8JP\nnjEUf39nBzTdQHFAdoYd506agj98bG7KcsnJhyEkDMCU0WbWfszUIRhUUYTRNWaf7GWXzkSFK6hV\nlRfh9kuPwODK1FXInjlvSUQsrsEwDKfaPDljVDU9MSef5kP1vmuPxusbduGZl825fk+Rmiu4JQc0\nd5CXU2Tw9187B63huJN1L7t0JprazIs+ewcrO4ja2Zq7u9hti7xVzJnMnjIYA8sCGF2ToiNZkvsW\nH43G1ljSnHfXC9Y6M6DU3BfdMHIbvAM+GcsunYnykB+KJODLfe2oTXMRaHNfdCkpRlnSFay53X/t\nHLz18V488Tdzu9p0GXqqi4NcYPDOEbswTSzdB6m0Ac2tewAkrh63la6Fuyg3psUQi5vLqEQB2Cuo\nKJaLIDQMhVi1y1xcJOowAAQUGYJsZn2TB0zA3z7ZZ3bv8oedtZhavVWQ5jOvnu250jZhPwRBgCDH\nYaiJhgMTBtdC/FhE3JpntTPvimAJmvfHzbXWmuQUpe1vbzH/muxqcSdwxs0GNIpVze5aQlJRFAK8\nsdvJrGPbJiIweb35GnIMorVlpN6ayFxDivke3t3/Nt7d/7Z5pyvzdobQ5Rj8483vG9EihII+TB0z\nEOs273aCfak+BOMGV2P9f8zzJEuCk50FjUrobaUQi5s9xwgATc0wg7ccd4oI/ZIPoiA4PbvF4ia0\nqOZzy1xz3n6fhOpQGfaFzfqD8jLRXLalJ4bNq0PeavD4DrOJSHFAdrLYE8dNw4mY5nmcX5FQWQpU\nxkc79w0vH4RqK3sqtiqFAWDW4MPx7KsfQ9s/BJNHVngyQ3e2NW5AotBt+qihngsxURBwyPBEtfWI\nwR23dxw5OH3w6ThsbjhD58mZt98nQdONRJerNMOZpUGf50PeHdB8nmHljnP0smQeQ6oP9oqQ31lf\nDQChoA+hoLeRjP1+3BcCK+bcDkkUUaIcWMFa8rntTFmJH2VJ8/2pmrR0hyyJqAz5sa85mnYIOVvc\nf0/2KE9n5DRLwVIWrKW4QAPM3/PIFH/Hydw1BRw2Pwg5w+ZWT+WInmKHKxd7yZdfkcwPJ1GDIvoQ\njWuQkpYYybIASUlsU1jXGIERC6BdS6x7tVtzhvzmB669WUMM7TAMwwrePkSsHbxkUUaFvxx14X0Q\ny/dCHmAPi7u2WlQVZ/mUvduYMyft7GGsojUchVhsZuYhV+FOyrWg9rB7WzmiH1vLk5QoxFAj9EgR\nEHd1B/OluPp27fhlX0BIFXsgKHFIrYOh7jSDn7OUyPp9+EQ/Lp9yEcT95m5DdsEaYFZdq9aOSoYu\neLL7BmsBgFi6z5mXtzd+QDwAI+aHEGzGjlZzyL0maSlSib1LkRxDaYld7Z0YNncXOk0UToTeYI6q\ndDXzce+6lG7tbUD2I/blKECXMLC8KG27R3exXbHcvXW86aQqWItZ65cVyRu8Az4JqmZkHDYHktY4\np8mQUs2P2wVv7u1dD4Q9kuD+PZX5Qx365OdCb2XeQOJiLpeZd3d4Mu8U1eBdybzNx2U+X+6Lg1R/\nS9nC4J0jTsGatYGCs/uTLWlLQ7tHud9ndmkSJA2KYG4Dam9q4ARvSXSGtQNyAPWNYQhqAO1qO1Td\nvD+shiEKIoKKGbTsIdKI0WbuYiQYgKZ4CuiqigagOdYC//h3nPvawq41yZriFGm1WNXmRofMW0Wj\n8hnE4hZUxMd4AkiqYOLeYcoZQg81QJDjHdbRFrn28lWsYUnPPLp1DGKRtca8bZIzn+tklFa2LMIq\nHrP+I8qS4BS6tEfi0PbVmIFb9cFd6mq0l0JvLYNUXg+p2mxp65f8zsWa3h6C6I9g475NKJKLMCxU\n63kPIcVV+e/XneO2P2R8kmvNtpgYdTiQzOeiCeeiSC7C5AET0j4mZm0vW1bs82Qi7vXSgiDg4gnn\n4uyxp3d7HW86SoqlYnbzEZ8ieoJOQJGg6ZmHzYH0WZV7Pa6U4jF2Zt3Y0vlFdjr2h36uM9RUpG4U\nZ6VjX8wV+/v+fXXGezHoyox9nS8VS9aF2J2x8U22MHhnQTiqoqXduyuY0yXMyvTiRlJ3LtVbgOEE\nb8Xa9UtUAV1GLK5Bttbl2vPjih28NRkCBNQ1heEXzMBoN2Zpj4dRJAcgSaK5VCfuh2EIaFWb8ceP\nVgEAtP2DPB9WA4OuTRosEVenOMOqKPcrIiL2Bh3OnLcVOBUVMdnMumvh3bIwOXifNOxYs2DKZjVZ\nEcvMPa6T23C650EXTVqAG2dcA3XXqMTx6e75bxHlYiLrtT+c7WAfMMzXtocYJUl0/qO3RVRA9SH+\n3ymIb0/sya3IImCIiG01h6ztna38kh/2SgB7HXZYi2B8+egOGzKU+q3aASXmbCBiaHLK5TElUuLi\npegAMp+ja2bivuN+2GlbSltxkeL5MEquDp9dMxNzhx/f5Z/dZUnLxlQtfebt90nQNHPY3C4sTCdd\n5uS+v7Pg3dDazeCdIvPuK+5h855edOVL5q2kec+pMu/OCtbc2/Wm09MLou5i8M6C6x96Ddc/9Jrn\nPrt6Fk7w9gZ3Q9CgR4LmGlZRcTbZ8CsSSopkCJKOPfti5iYMgt061GroIgnOhhRtERXhqIZia3/h\npqg51xpWwwhamar5QWj2zd7eth0fNXyKQfIIaPW1GDwgEVA9OyxZhg8yX3dAqd8pShs62Oc0QnEy\nb6dtpwpVtIbsFe/8kXsI8VuHXdlh/99xQ8xWlfb/PfcmAYD3P+DQkhqMLR8FGCnmvGF1RLOqdodV\nlzgZROyzqYhvH4ehhrk8ys4AZVFwisbs4VmtvhbaPnP4fNSQUqdHtxEtcvrFA4Bf9jkdLY32xEhA\nqsy3LJAI3s7fhC55AtL/G30Kjhx8OAJS9pbq2PPcQyqDnkrsQTnaaMHdrEiWBOiG4azEUBTJO2yu\nmHPebRHVHJXqJCBJXVjDW2o1jXFn92NqzIu5IQO6N8ztVyQU+eU++2B3S3Vx0l2DrL+T0mJfhkf2\nrXS/d/v3ka63eTK7F7zYyd9YV9eJ97b+ffmUp+xCG8MwnA+WRPA2/9XQMXhDC0DdNQbV46LYGfkC\nsLbLnDV5IF54C04wcipXreCtSCIgxWFEfE5L1XJ/KRoANESbMArmnLddLFUe8mPP/nZz/tgXRZEc\nwHdnX463yxow3bUcZXz5mMTxqTK+MuJknHTUZLz7aT3iqo6nt5hrzysrJWyPxc1kU1NwzNQhEMsF\n/CeyCZKswfBFYBhCh0zbfXt8xRgIgoAbzp0Gnyxi9/52zJo4CDe/9ifnfert3jluWRKwaOL52NL4\neYcLjbISH5paDRiGGfyrS8px+lEjURr0YfaUwSgOKLj27EPx8z9thLprDIQae7g8kXnbRU32nuTj\nhpbhzGNGYV9TBDPGV+FnqzbA3GFcgN40EKK1I5tn2Nx1wTFj0FQkq7CaqMiDvsCGfebPOe+YyZ6i\no1NGmu1qX3rjC+e+AaW9u2/wLRdMx0dfNGDK6AGIxTV8/YQxB1Qo1VPupZR2sLH/litCfs8oix3I\nW9oT+5ink/yhe+uF0xGNe7OpMTVluPTUCc4GGwBw8hFD4fdJmDHeu/NWV10wd5wzrN/XejN4H35I\nFRbOH4+jJg/utdfMhuRs+vuLDkdTa+Iz1/130dn5GTE4hEtPnYBDhqcfteqrCzQG7yxStcSmCppm\nz3mbHxya4A3eEPREYxIkdtzy+yRnq0l7DbOdeTt7RUsCdCEOQwtiZ521UUdxJT4PA/sjDeZuZLrq\nZN5V5QEzeFsXEjXFQ1CsFOG4ad4sa3jpUAwOVmN3+15EP5yFY4+ag1DQh+Om1WD9B4lK7WDQgICw\ntbm9gLOPHY09cRn/ec8cNheUKBD3IVDi/aAt9lQrm+996hgzCE8YYfX/1v3QxXarUKxjRe+RVYfj\nyCGHdzj35SV+8z+rIQCCgepQGfyKhLlHJPp6H35IFYJ+Ge1R1cmU7Yst93CsnXkfOnoAJo9MVH+7\nq5zj28cDhoihlRVQRNmpcTDCJdAjRThhzHTPHL1znEWJC5KdrealwNETRnV4HODNEFJ1EOuJytIA\njp4yBIBZiX3aUSMyPKN3uTNve5jX3rSnqjzgyYrt77dFVFRXdF44l5xVpbsYOW5ajee2KAgd7jsQ\n44ZmnqLIld4M3pIo4sQZQzM/sI8lz0PbIympZJpKyPR30JWitmzo+zGdg5j7Cl/Tra+tYXNnj2Xz\nljlfavfztpc7WTtuOXt0W8HSZ+/HbDdOkTSn4GxHnVn1XVtqZtANkUa0W/PRRdY+t1X2jkuy+bo1\nJemvom8+4jpEP5wJI1zq3bQ+oCSGyJUwRH8EmpUZK7LobK0nyCoEXxRGzO/Zlxfo2s5DJa3mHLPe\n2DED6uxDKRS06wLMoFDiSz386QzJG4bntuyZ844797l55v00H+LbJmGYPsN6Qet+Q0R0w3E4b/xZ\nKX9+kc/nFA8CZuFdukpud/FVLqtac8Gdedvnede+xI5x7mFz9+890/RBbwaufCX1g6H7XMvlUHa6\nTaeyrfB+qzlkL7sCEsPmguBu2WmxN9+wMm87wxZE1QreiblQAAiIAc9rGFYWb2gydlrBe0SlOV+8\nP9Jo7kcNIGgFVLvNYeyzQzEiNAynjpyb9j0E5IDTKtL9HyIYkJ3g/Z+ItZtYuGPwhq/dXI8eD3To\nhdyV4F0ePgSRjXMQ+++UDt/r7EMped/idEt07Kvu5P9/dntUIJF5JweCVEU79lW49+VStwwFzGVP\n0Q+OcgJ4mb/jDlm2rhTP5Bv7nRquM2af50TmXeQ59+7fe6bCqT76XO1XCvECJpdD2U5ilmMcNu9F\numF45lI8mbfmLVhzb7cJwargtoKz0SHztrqLWXPefsneDMMM3ppgXQioMnZY2/UNG1AJn6hgQ/1m\nJxjYa4ZDRebws948ELfMPK/L7y8580bS7kh24xdFFlFkWM1gfFZns5i/Q1WwLMo4YegcDAqmn1eU\nZbFDoZqts0KT5Cvv9MHb/Dd52FwUEsHbnitLfs1Uy4DsD8p0u8h1PE4RRqzIXG5WuRdiJzsu2Mv4\netJoo78RBAGGYSRl3lbw3tcOvyIhFFSSNjFxZ96dz3l39fdwMGPwzi7nsz3HDp5PgT62e387rrjn\nn/j7267+1/HEsqrkgjVB6ph5G9awuW7vNuWLwO9zZ97mtZZTdWy9hu7KvJvaYigt9iHgk+GTzCD9\nft0mAMCospEA9OeBHQAAIABJREFUgPJQ9ypFk5tcOHtuW/RwCLIkmPv/Wpm37rOat8T9Kfv+njv+\nLBw39Oi0P7OzZRzJnbHcyoq9c8IBOXWBlz13XGQdWyITTPS/brcadSRn+ikzbyl1Jp+JfcEW19MX\nOdmvOXxQ560h84m9JMtd4S675rQHlgc6jES4f++ZMu+DbXqhOwrxHOTygqWvLqaZefeStz7aCwBY\n+ddPnPvcm444QyuiO/M2AAiJPautYFgSrwX8m6AM+wQ++WRnztvOvINyEIh3zLxrKspR7h/gVMi2\nurbpBIDRZWYR0qSRlTh99ghMH9e1Stprzz4UX9a3ej4EKkJ+nHroNLypbcWYkrF4+8P9MNpKofgT\nFfHmkjd7NzJft1oHdjZ3lep7t1wwHe9vrcfXjh8DUQT+ZT9WTP2zrz17Cl5Y/wVOn20VaLlesqqi\nCKOGhPD5LnP0IPkDIdV8qzNsbkVaWRJxwdxxad8DAFx51hSsa9iBrZFdnuHjZGcdMwqqpuOsY1IX\ntOWj75w/Df/3n+04+fBEEdSx02rQ0h6Hbhg4ekqiHuPCueMQ8MnY+mWip26m4D24MohTjhyOKaMq\nO33cwcyvSDhzzshO29MebIr8svmeh6R/zxfPH9/lTUk6M2N8FU6cXotjpw3p8WsdCAbvA7S18b+o\nC9fjqCHezQXswCYEWiGW1UPbMwLRlJm3GagF0XA2FnGG0q3g7YsOQoV/KBqKdwBSvMOcd5EcgBAX\nnNakLaq5DehXpo/F7JpEj+sLJ3wNL3z+NzRGm5znAeaQ8NeOTywDy+TwQ6pw+CEdA/3Xjp6Mq6uO\nwsdb67B+rbkft/sqNCgH0BSzh/SVbgZvMem22XMaSD1sPmFEhVOpfv5J4/Avc5dMKGLq4dXqiiAu\nPTWx/lpAYthbFARcd85UfPfnr6c8lmCKLlPJAf74aTU4cXpth8e5nXncGEzYfiZ+uWEfFhxydtrH\nBQMyFn7lkLTfz0dDBhTjklO869/H1pbh21/vuKzOXimwbXeLc1+mYXNBEHDeiWN74Ujz21ePHZ35\nQQeZTO/5pF6qmpclsU/+XzJ4H6AH3vkFAGDW4Bmebln2XHdgqtmcJdJa4Q3emrdgDYCZfetyIhu3\nhs1jcQ2KHgREQBOirszbqjZXJBTFi9BqBe9P2z+CKIiYPND7ITin5kjMqTkSb+95DwNTNFzpLe4P\nUPeSnqASRFMssZtXd7bLSw6YPlmCqplDy501TrAdWzsbr+5ch9HWlEEmiepz89+yksQUQ3Kmn7pg\nzTts3tWGVhWBciyddWPXHlzg3Bdt7o0/iAoJ//K7SdVVZ04ZAJKnWARRSxo2TypYgznsbcQDEKwm\nJPYcciSmQdB9ZvAWI4iq1lIxe523LCIoF6FNaoXgb8OeyJeYWDnes4mF2+GDDuvRe83Ep4hmP2rd\n8GTe7mpyo7uZd9J8kt8nOXPQXWn1eN74s/DVMachIHdvXbS3mYP3WFIOm9tz3vbwd+FNN2ad5ClY\n40cYFSYWrHWTmlRY1CGQGELSsLm9zjuRedubeiSGzc3gFo1rEKyGJHEjirC1TttemuWTRQSVIkCO\nQRpgNvaYOWh6z99UNwmC4HyIuueQPOuVXZttHIjkbPdAd0USBfGAAnfyum+3mKp5bqfaijIx523/\nfEbv3uYtWOvfG2QQZQuDdzfFda3zBwg6/vi3T7Hps30AXJm36FoTaFecJw2bb9vdgi92mvPcUSOM\ntri53tVu0qLIEkqUIATRgFS1A7IgY2rV5J6/qR6wP0QVxTtsbjNUxbOTU1clD5tne79cZ847xfda\n2uOe26kL1rpXbU5dx8ybiMG725Izb7ufucMKyA88/T6AdMPmKgZVFGFghbWHtWvplb0dZtQIO3tx\n25m3IotOYBT9EQwtHppoitJHjpo0CANKAzh8fLVzX1BJtAOVoXSrjaA7eI8cHMKF88b37EAzSZEo\nf+/iGZgwvByzJ3v34lZkEbMmVnsKopKHzZl4976JIyowqDKISSMrUFHau21iifIFL1u7SdW9WVgs\nufuVNY9tf3inK1i7Y9FMPP3uK3izHYAuosgvIRzVnK0129VE8LaboiiyiGI90XQk5Ov7db9nHjMK\nZyYtYXIPm/vl7q0td+/zfPslR2R9swdnnbfr1zRuaDluuXBGx8cKAq4+y+z89vQ/twBwVZsbicdQ\n7xo3tBx3XXVUXx8GUZ9i5t1NquEdNjdbV7rms63MuzJkZsTJvc0Bs1GLTxFRVGT9GgzRqdy2s+zW\nWBva4+3wiT5nWN0niyj3JdYvhtIUqvU1d8FadyrNgUTBmiSaLUaz3Xwh0XGte+Pe9pJBnfVqRJRF\nWc28V6xYgffffx+CIGDp0qWYOjWxdnPlypVYs2YNRFHElClT8P3vfz+bh9LrkofN46ruZNsAnK8H\nWMN6qea8BVmFJIoIBqzgrZutIOubIs6weVu8De1qGAEpALs1hSKLKJMSwbs0zaYbfc09593duWq7\nOMkO2tnecEBI7pd6gOSkJi3MvIkoG7KWeb/55pvYtm0bnnrqKSxfvhzLly93vtfa2opHHnkEK1eu\nxBNPPIGtW7fivffey9ahZEVyG8u4qnn7lVvBu9Rqv6m72qMaqnnNJPniePHzv6MdjQDMOe9ia39i\nSfdBgIDWeBva42FnO0/ALFgr8yeCd1mgf3ZOch9z8qYkXWVn3nZGm+3t9xLD5t2M3slLBhm7iSgL\nsvZJuG7dOsyda+5WNWbMGDQ1NaG11exzrSgKFEVBe3s7VFVFOBxGWVn6/Vb7Ql1jGI+t/djZDjJZ\nqszbvVOY0/LUCgLujUnsrFoYsAN/+XwtXtn5uvVY0Wk6IUkiipUgGqNNiGgRTxaryKI3ePeDOe9U\nZDExsNPtzFtK7K8N5K5Pc7eLxa0n9tU2gURUGLI2bF5fX4/JkxPLlyorK1FXV4eSkhL4/X5ce+21\nmDt3Lvx+P04//XSMGtV5v+aKiiBkuXeXCVVVpZ8rXrHyHWzZ3oiyUABXnNVxO8rikOJ5viCJiXXb\ngJN5y4qEqqoQJFkCYEAQAD3uAwLtHY+nrBhl1hy5IokYXl6DD+o+BQBUliSC9eDqEAJFiYA9fNAg\nVA3su3nvdOfRCNYC7wB6OIjSEn+n5zudygpzSkCWRef5AZ+E8cMruvV6mVx46kT86JE3cN68Q7r1\n+iWhAKqqQrj6nKn45aoNOHXO6C69TjbeS6HhOewdPI89l4tzmLNqc/cwZGtrKx5++GG89NJLKCkp\nwSWXXIKPPvoIEyZMSPv8hoaOwa4nqqpCqKtrSfv9fY1h69/2lI/b19CCOiVxf2tbzDNsPn54CB/s\nBMLhOOrqWtAeiSWK1TQZhi4msnPLVacfin+tM3+uKAoY5B+ED2AGb9lINKNoaQ5DiyZ+dVq72Ol7\nyabOzqMAH04oPh8vvl0HjDO6dYzhtqj1WnCe/7MbjoMgICvveVRVMX57y4kQRaFbr9/cHEZdXQtm\njhuIw7v4Opn+FikznsPewfPYc719DtNdCGRt2Ly6uhr19fXO7b1796KqytzcYuvWrRg2bBgqKyvh\n8/lwxBFHYNOmTdk6lG6xLzbSjdJ2GDbX9KRtPs3MW7NeJ2q0QvBHrBcXALXjdZMsys7wuiyJqA3V\nON9zL7tSJNFTCFWi9M9hcwCo8g0GNB/8Svf+1Ox13u4qc9GqPM+W3hqaL8StGIkoN7IWvOfMmYO1\na9cCADZv3ozq6mqUlJhBpra2Flu3bkUkYgazTZs2YeTIkdk6lG4xUqzTdY8edChYi2uA7B42N7Nq\nOxjvqFqDwNRXrRcSYcQ7NpdQRBmqtaRMEgUMK0kEb/dabrt/+IiQuctSd/t254IdfANK9wZ5ZDm3\nc91ERPkga8PmM2bMwOTJk7FgwQIIgoBly5Zh9erVCIVCmDdvHr7xjW9g0aJFkCQJ06dPxxFHHJH5\nRXPIvdRH1VX88aNVmO3aBlQ1OmbeYlFiqMQQzO87Veae1xZgtJVBLDYff+GEr+HThs9RVTQQmlYH\nwCxYqykZjONqZ0MWZcypmYUnsN45JgD47uGLu70eOVfsgjNfN1qjAole6WKWq8x7C+vUiCgXsjrn\nfdNNN3luu+e0FyxYgAULFmTzx/eIu8nGxvoP8cbut/HG7red76tJvc1jqg6xsinxfGgQBQGaYeCL\nvc3eFzdE6K3lQPUOAImtO4FEm1VZFCAKIs7vZH9nScxun+/eYGfe3a02l6zny3mSeff3iykiOjjk\nRzrTBxKZd+pGG8lz3jEtBiHYAr3NrArXoEIUBei6gR/8YZ33yboIvS310rhZE83+2ccfVpPy+/mm\nImQO6Q8o7V7v9UTm3b+D9xGHmPUcIwf3zzX3RHRwYW/zDARBgF/s2Jc7ntzbXG6EIBjQWiogBJuh\nG6qzx7Wn8xoAASLu/8ZXsOaLCMaVj/Z878hJgzC2tgyVKTZc+NkNxyY6teWJMbVluPvq2RhY1r3g\nbQ+79/fg/c2zJuO8ligGlhVlfjARUQ8xeKfhDJsLqbt6JQ+bq4K5lE2PBiHpkpN5R2Kad/03zD2m\ny0sCWDTp/JQ/e0CaQJevexdXl3c/oLl7m/dnkigycBNRznDYPI3EUjEBmqF3+H6HLUFFa9vOmB/Q\nRWiGBkkUsLehvUPmDZ2nvavsXuH9PfMmIsolRpE03FXDWlKWDXirzQ3DgC5Za7jjPhi6BNWIQxIF\nGAY6ZN727mCUmZ1550vBGhFRLjCKpJEp845riYC8e387oJidwIy4HzBEqIaayBalpODP4N1lSp7M\neRMR5RKjSBruOW/N6Dzz/vmfNkFQYgCAUn8I0BKZNwAIHQrWGIi6yqdIkCUBRT6WZxAR2Ri803A3\nadFTDZu75rwjMRWCEoUiKvjhJbNRO6AUcT0OwT67ycPmev9fn91fyJKI755/GM4/aWxfHwoRUb/B\ndCYDM/NOVbCWCOiabkDyx1DmC6G02I/yYDF2RXRIkpW+JxesaflZNd5XDhle0deHQETUrzDzTkN3\nNWlJOeftWuet6ToMKWoOmQPwS+YabdGa6xaS57w1XjMREVH3MXin47RHFVLPebuGzXUhBggGQtbu\nXgEreAv2RiVi0rC51rHpCxERUVcxeKdhrxRLV7AW01QnOzdEs9K8WDG37fRbu3wJaTNvDpsTEVH3\nMXhnIKYpWPt8dyN+9Zy5B7kmmpXmQTt4S1ZmbQftDpk3h82JiKj7GLy7INWcN0QNb31sbt9pWMG7\nWDaDtzNszsybiIiygME7A90wUg6bu7umGZJZvNZh2FxUARgQ/OGkF2XmTURE3cfgnYFupMm8reBt\nGAYMKXnY3NoRTFQhDdoGsbjZ05iFTVqIiKgnGLwzMAwj5Zy3ORRuQDcMCLKdeZu7StnD5pBUSKX7\nAQCXTlqQk+MlIqKDH4N3BuaweYrMGwAkFf/e+SaU2q0AgGDSnLchqBCKWmDEFVQFB+bkeImI6ODH\n4J2BYaRYKmavAZdUPPnpaufu5DlvTYpADIRhREKQRc5zExFR72DwzkDXOxasybDntL33FyctFYvI\n9eY3wqEO+38TERF1F4N3BobRcT9vUU/MaeuRIud+RTSXgNnD5mFxHwBAiIYwpHgQJMOH+M4xEFiv\nRkREPcDgnYFhGNCT5rwF3cysBUmFEU0Eb8GKyvawuV1ULmh++CQfZukLoe4cl/2DJiKigxqDdwYp\nC9bsJiuSCgjmBPjo8Hzn285SMYtgPV4wmHITEVHPMXin8NH+TyH4zMYqqQrWjLgVjK3gbRgCSvUa\n5/uKKENxFagJujeYExER9QSDd5KWWCt++t5v4J/2CgAz8/7vnibPY1S79kxSIQgGYAiQRG9WHfKF\nnK9FnbuIERFR72HwTtIWbwcAp6hM1XTsaWjzPMauX7MzbxgCxKQzGfKVOF+LOnuZExFR72HwThLT\nY57bcVV35rVtumadNsnsXW4Gb++pLHUFbwHmELoB7+sQERF1B4N3koga8dyOxXVAMAvWoh8dgaJw\nLbR6c37bnXlLSeu/Qkpi2Dz5e0RERD3B4J2kPSl4x1XNybz15gEI7T0aajQAABCUqBO8haQzWepP\nBG8hKXgn3yYiIjoQ7NmZJBz3bt8Zs4bNDQMABLRH44Dqg6EqEAJt1lruFAVrimvOW2SwJiKi3sPM\nO0lY6zhsLgg6YJinqj1ilprr4WIIgTAEUYNhCB0CtGepGDNtIiLqRQzeSbyZt4G4pjtD4wAQjpql\n5kakGIJgQPBFAUPskHlLouR8bX+L5WpERNQbGLyTeDJvwUAsrgGCDlmUMLqmFLo5fg4jXJx4nCFA\nTMquJw+YABgCYtsmdPgeERFRTzB4JwnHXcFbVJ05b1EQ4VcS2bQeDSYel2LYPOQrwYTGi6DtGclh\ncyIi6lUM3knCqmvYXNSdanMR3uAN3fV1ig5rAKwit8SwORERUW9g8E4Sdi0VEyTVWectQoLf5w7Y\n7lPXMfMG4AyxC4zeRETUixi8k3gzbw2abkBwhs1dp8u9Q1iKOW8gEbyd2M2KNSIi6gUM3kncTVoE\n0W5ibgZvn2vY3NATpy7VUjEAmDKyEgBw2NiBnvs5BU5ERD3BJi1JYpqrt7lkB28doiBBkdJn3qnm\nvOfOHIZDhldgWHVJh+8RERF1F4N3kqh7Y5KkzFv2BG9vIE+VeYuCgBGDQx3uJyIi6gkGbxfDMBDX\n4s7txLC5DkmQkoK3O1innvMmIiLKBs55u8R11bttp5TIvCVBhCy5h8q9WXiqYfNkrFcjIqLewODt\nYs93S4JZmCaIGiCqEATAJ/o9mbe7YC3dsHk6zNGJiKgnGLxdolbwDohW9zRRg+Azq89LlNABF6wl\nG1xpvu7omrLeOWAiIipInPN2iVvFakVSEG1aCyCp5sYjAEqVEGQxdcGakWadd7KTZtSiOCBj+riB\nGR9LRESUDoO3i5N5C2aGLEgaBMXMvEt9pZCTsm33110ZNpclEXMOHdJ7B0xERAWJw+YuMavS3O8M\nm6vOsHmZrzT9UrE07VGJiIiyIWPw3rp1ay6Oo1+IWcPmPqMIgJV5W8Pm5f4yyHLP5ryJiIh6Q8bg\n/e1vfxsXXHABVq1ahXA4nOnhec0eNldgBm9zztvMvCuKyrwFa8jc25yIiCgbMs55P//88/jkk0/w\n4osvYuHChZg4cSLOPfdcTJ06NRfHl1N2gxZBV2BoIgRJBXwRGLqIkFKMqBRJ/URD5LA5ERHlTJfm\nvMePH4/rr78eS5YswdatW7F48WJcdNFF+O9//5vlw8stO/OGIQG6DLG4GWKgHdq+wVCUpA5rbhw2\nJyKiHMqYee/cuRN/+tOf8Je//AVjx47F1VdfjWOPPRYbN27EzTffjGeeeSYXx5kT9py3oUkwNAmC\nYt6v7hwHRfL2Nvd0WwOYeRMRUc5kDN4LFy7E17/+dfzhD3/AoEGDnPunTp2aceh8xYoVeP/99yEI\nApYuXep5/K5du/Cd73wH8XgckyZNwp133tmDt9E77A5rhi4CmnlqDEOAEQtAlrztURXZvc5b5Jw3\nERHlTMZh8zVr1mDkyJFO4H7iiSfQ1tYGALj99tvTPu/NN9/Etm3b8NRTT2H58uVYvny55/t33303\nLr/8cjz77LOQJAlffvllT95Hr7CXihmqBEO39u6OKwAESJLgqTZP7rbGYXMiIsqVjMH7e9/7Hurr\n653bkUgEt9xyS8YXXrduHebOnQsAGDNmDJqamtDa2goA0HUdb7/9Nk466SQAwLJly1BTU9OtN9Cb\n7DlvXXNl3roMSTSryd0BO3nZGIfNiYgoVzIG78bGRixatMi5fdlll6G5uTnjC9fX16OiosK5XVlZ\nibq6OgDA/v37UVxcjLvuugsXXHAB7r///u4ce6+z57x1VYSzFEyTnEDtnvNOzrwZvImIKFcyznnH\n43Fs3boVY8aMAQBs2rQJ8Xg8w7M6MgzD8/WePXuwaNEi1NbW4qqrrsLLL7+ME044Ie3zKyqCkGXp\ngH9uZ6qqQp7bwqfmMUqiD7D28jZ0CQFFQlVVCEUlifddFFDgXMIYAgYOKO7weoWiUN93b+I57Dme\nw97B89hzuTiHGYP39773PSxevBgtLS3QNA2VlZW49957M75wdXW1Z7h97969qKqqAgBUVFSgpqYG\nw4cPBwDMnj0bn376aafBu6GhPePPPBBVVSHU1bV47mtpN39GuN2A4Lf28tYlSKKAuroWxOJa4sGu\nixEYApoa2+EvwOQ71XmkA8Nz2HM8h72D57HnevscprsQyDhsPm3aNKxduxbPP/881q5dixdffLFL\nmfecOXOwdu1aAMDmzZtRXV2NkpISAIAsyxg2bJizTnzz5s0YNWpUV99L1tjV5mocTuYNXXIqy93z\n3JJnqRg7rBERUe5kzLxbW1vx5z//GQ0NDQDMYfRVq1bhtdde6/R5M2bMwOTJk7FgwQIIgoBly5Zh\n9erVCIVCmDdvHpYuXYolS5bAMAyMHz/eKV7rS1E9BlmUoaqAPedtqDJ8VtB2B2jJ9bVhCDBARESU\nGxmD9w033ICamhq89tpr+MpXvoLXX38dP/jBD7r04jfddJPn9oQJE5yvR4wYgSeeeOLAjjbL4loc\nPlFBXNUhfjEdyvCPEd5+CJSqjgMU7gI1QTAYvImIKGcyDptHo1HceeedqK2txa233orHHnsML774\nYi6OLeeiWgw+yYeYqkNRy1C291hA9UNJUSgnJbdKNRi+iYgoNzIG73g8jvb2dui6joaGBpSXl2P7\n9u25OLaci2kx+CQz81ZkCbpuBmR3NzWbuymLJAuoCAVydpxERFTYMg6bn3XWWXj66adx7rnn4rTT\nTkNlZSVGjBiRi2PLuZgeQ7lYigZVQzCgQNV0AHDmvN3cwfvsY0alDPBERETZkDF42wVngLmka9++\nfZg4cWLWDyzXDMNATIvDJ/kQ13Qosoj2iAogc+bNGW8iIsqljOmiu7vaoEGDMGnSJCeYH0xUXYUB\nwwzeqg6fLELVzcxbSbEVqOgJ3kRERLmTMfOeOHEifvKTn2D69OlQFMW5f/bs2Vk9sFyLWq1RFVGB\nqhlQZBGaZs15KykK1kT3rmIM30RElDsZg/eHH34IAHjrrbec+wRBOOiCt92gRbY28VZkyZnzTpV5\ne4fN9RwcIRERkSlj8H788cdzcRx9zt4ONBG8RSd4y3IiUFeVB1DXGPEOmzPzJiKiHMoYvC+88MKU\nc9wrV67MygH1leTM2yeLUK1hc9k1RL78yqMQi2t4+p9bnftYsEZERLnUpQ5rtng8jvXr1yMYDGb1\noPqCvZe3CHN+293HXHb1MZcl0dkaVI8UQQyEUSQX5fBIiYio0GUM3rNmzfLcnjNnDq688sqsHVBf\nienmsLmExLC5rUM3Nfs5Hx+B4NCdOPb4g2v+n4iI+reMwTu5m9quXbvw+eefZ+2A+krMybzNU+Ju\nzCKLqZbGGTCixZD3TIFPUlJ8n4iIKDsyBu9LLrnE+VoQBJSUlOC6667L6kH1BSd4GzIA3ZN5y510\nTzv4VrwTEVF/lzF4/+Mf/4Cu6xCtoq14PO5Z732wiOnu4B3zbEYipxk2JyIi6gsZo9LatWuxePFi\n5/ZFF12El156KasH1RfsgjUY5ilxr+2WUgybc3UYERH1lYzB+9FHH8WPf/xj5/bvfvc7PProo1k9\nqL4Qt9Z5Q7fmvBV3wVr6wfGDsVUsERH1bxmDt2EYCIVCzu2SkpKDMmBFtCgAQLCCtzvzdq/ztjHx\nJiKivpJxznvKlCm44YYbMGvWLBiGgVdffRVTpkzJxbHllB287cxbUdzrvDnnTURE/UfG4H3bbbdh\nzZo12LBhAwRBwJlnnolTTjklF8eWU1HVCt6anXm7C9YOvpEGIiLKXxmDdzgchqIouP322wEATzzx\nBMLhMIqLi7N+cLlkZ96GZgZt91KxUNDX4fFVZQEAQG3VwXUeiIio/8s4Hnzrrbeivr7euR2JRHDL\nLbdk9aD6gp1566oZvH2yiOVXHolLTjkEIwaHOjz+lCOH44KTx+HKMybl9DiJiIgyBu/GxkYsWrTI\nuX3ZZZehubk5qwfVFyJaBIqoQNPM24osYsiAYhx/WG3KxyuyhHkzh6XMyomIiLIpY/COx+PYujWx\ng9bGjRsRj8ezelB9IaJFEZD8iKvWHt6ddFUjIiLqSxnnvL/3ve9h8eLFaGlpga7rqKiowL333puL\nY8upqBpFQPYjxuBNRET9XMYINW3aNKxduxarVq3CkiVLUF1djWuuuSYXx5ZTyZm3z9UelYiIqD/J\nmHm/9957WL16NV544QXouo4f/ehHmD9/fi6OLWd0Q0dUi8Ev+xHXmHkTEVH/ljZC/eY3v8Fpp52G\nG2+8EZWVlVi1ahWGDx+O008//aDbmMTuax6Q/IjHzYo1Bm8iIuqv0mbeDz74IMaOHYs77rgDRx11\nFICDt4931FrjHZADaGPmTURE/Vza4P3yyy/jT3/6E5YtWwZd13H22WcflFXmABCx1nj7JbNgTRBS\n7yRGRETUH6RNL6uqqnDVVVdh7dq1WLFiBb744gvs3LkTV199NV555ZVcHmPWOZm3VbDmk6WDdpSB\niIjyX5fGhmfOnIm7774br776Kk444QT8/Oc/z/Zx5VRYjQCAWbCm6hwyJyKifu2AolRJSQkWLFiA\np59+OlvH0ye8mbfG4E1ERP0aoxSA9ngYAFAkB9DcHkdx4OCqpiciooMLgzeAdtUM3qLuRzSmoao8\n0MdHRERElB6DN4D2eDsAIBoxT0dVeVFfHg4REVGnGLwBtFmZd6SNwZuIiPo/Bm8kMu+WVvM2gzcR\nEfVnDN5IzHk3NZnd1TjnTURE/RmDN4C2eDsUUUFLmxm8K0L+Pj4iIiKi9Bi8YQ6bFytBRGPmpiQ+\nhduBEhFR/8XgDXPYPCgXIRLX4FNEiGyNSkRE/VjBB2/d0BFWIwgqRYjFNfiZdRMRUT9X8ME7rEZg\nwECxHEQ6NoGoAAAYmElEQVSUwZuIiPIAg7ddad6so6E5Cr+PwZuIiPq3gg/eMc3co3zL9jYYADNv\nIiLq9wo+eMd1M3gbunkqGLyJiKi/Y/DWVfMLBm8iIsoTDN5W5g3dDNo+peBPCRER9XMFH6ni1pw3\nDPNUBFiwRkRE/RyDtzPnbWfeDN5ERNS/MXhzzpuIiPIMg7cz583gTURE+YHBW/MOmzN4ExFRf5fV\n4L1ixQqcf/75WLBgATZs2JDyMffffz8WLlyYzcPolDNsbhWsiSI3JSEiov4ta8H7zTffxLZt2/DU\nU09h+fLlWL58eYfHbNmyBf/5z3+ydQhdkrxUTNP0PjwaIiKizLIWvNetW4e5c+cCAMaMGYOmpia0\ntrZ6HnP33XfjxhtvzNYhdEksqcOapht9eThEREQZZS1419fXo6KiwrldWVmJuro65/bq1asxa9Ys\n1NbWZusQukR1qs3NzLu4SOnDoyEiIspMztUPMoxERtvY2IjVq1fj0UcfxZ49e7r0/IqKIGS5d4vJ\nqqpCkD63bugi5h85Al89aTwkznsfkKqqUF8fQt7jOew5nsPewfPYc7k4h1kL3tXV1aivr3du7927\nF1VVVQCA9evXY//+/bjooosQi8XwxRdfYMWKFVi6dGna12toaO/V46uqCqGurgXN7ebrGrqEuTNq\nsH9fa4Znkpt9Hqn7eA57juewd/A89lxvn8N0FwJZGzafM2cO1q5dCwDYvHkzqqurUVJSAgA45ZRT\n8MILL+Dpp5/Gz372M0yePLnTwJ1NqqvaXBILfuUcERHlgaxl3jNmzMDkyZOxYMECCIKAZcuWYfXq\n1QiFQpg3b162fuwBi7matMgSh8uJiKj/y+qc90033eS5PWHChA6PGTp0KB5//PFsHkannI1JdAmy\nxMybiIj6v4KPVqquWg1aBBaqERFRXij44B3T4xAMs4qdmTcREeWDgo9WcSt4CwJboxIRUX5g8NZU\nCKw0JyKiPFLwESuuxwFDYqU5ERHlDQZvPW4tEyv4U0FERHmioCOWYRiIaXFAl1lpTkREeaOgg7dq\naDBgwGCDFiIiyiMFHbxjWsz8QpcgcdiciIjyREFHLDt4G5rEYXMiIsobhR28rb7mhsaCNSIiyh8F\nHbHszFtn5k1ERHmkwIM3M28iIso/BR2xYnoi82a1ORER5YvCDt4sWCMiojxU4MHb3stb5FIxIiLK\nGwUdsRLrvGXOeRMRUd4o6IjlLBXTRQ6bExFR3ijs4O3qsMaCNSIiyhcM3gCgSdzPm4iI8kZBR6zE\nsLkERSnoU0FERHmkoCOWe9hcYcEaERHliYKOWFFnqZgEHzNvIiLKEwUdseJWhzWDmTcREeWRgo5Y\nMVfmrchS3x4MERFRFxV08I46c94iFLmgTwUREeWRgo5Yqq5CggRAgI/Bm4iI8kRBRyzVUCEKMgAw\n8yYiorxR0BErrschwpzrZvAmIqJ8UdARK66pruDNgjUiIsoPBR28VUOFYJingJk3ERHli4KOWKqu\nQrAybxasERFRvijoiBXXVQgG57yJiCi/FGzEMgzDzLw5bE5ERHmmYCOWqqvmF8y8iYgozxRsxIpr\nVvDW7cyb1eZERJQfCjd423t5W8PmLFgjIqJ8UbARy868DZ1z3kRElF8KNmLFrMwbmghBACRR6NsD\nIiIi6qKCDd6qlXnrugBFFiEIDN5ERJQfCjZ423t5G5oIRSrY00BERHmoYKOWXbCmaQJ8CivNiYgo\nfxRu8LaHzTWBmTcREeWVgo1acd0O3iIUpWBPAxER5aGCjVpxa85bUwXIYsGeBiIiykMFG7Xcw+ay\nzEpzIiLKH4UbvK2CNV0TITHzJiKiPFKwUcvpbW6IkCVm3kRElD8KN3jbvc11ETKrzYmIKI8UbNSy\nm7TAENkalYiI8krBBm9nP29dhMTMm4iI8kjBRq2Ys6uYBJmZNxER5ZGCDd5x97A5C9aIiCiPyNl8\n8RUrVuD999+HIAhYunQppk6d6nxv/fr1eOCBByCKIkaNGoXly5dDzOGSrYgaNb/QJBasERFRXsla\n1HrzzTexbds2PPXUU1i+fDmWL1/u+f4dd9yBhx56CE8++STa2trw6quvZutQUgqrEQCAocssWCMi\norySteC9bt06zJ07FwAwZswYNDU1obW11fn+6tWrMXjwYABAZWUlGhoasnUoKUXiZvCGJjPzJiKi\nvJK1qFVfX4+KigrndmVlJerq6pzbJSUlAIC9e/fi9ddfx/HHH5+tQ0kprEYhQLCqzZl5ExFR/sjq\nnLebYRgd7tu3bx+uvvpqLFu2zBPoU6moCEKWe2/f7XA8Ap/kRzsElJYEUFUV6rXXLjQ8dz3Hc9hz\nPIe9g+ex53JxDrMWvKurq1FfX+/c3rt3L6qqqpzbra2tuPLKK3HDDTfgmGOOyfh6DQ3tvXp8YTUC\nBQoAIBqNo66upVdfv1BUVYV47nqI57DneA57B89jz/X2OUx3IZC1YfM5c+Zg7dq1AIDNmzejurra\nGSoHgLvvvhuXXHIJjjvuuGwdQqci8QgU0QcALFgjIqK8krXMe8aMGZg8eTIWLFgAQRCwbNkyrF69\nGqFQCMcccwyee+45bNu2Dc8++ywA4IwzzsD555+frcPpIKxGUSGXAgAL1oiIKK9kdc77pptu8tye\nMGGC8/WmTZuy+aM7FddVqLoKQTffPoM3EVHfevnlv+OEE07u0mN/8pP7ce65C1BTU5vlo+q/CjJq\nRa0GLbv2xgBw2JyIqC/t2vUl/va3tV1+/PXXf7egAzeQw2rz/iSimcHb0MzqdS4VIyLqOw88cA8+\n/HAzHn30N9B1HV9+uRO7dn2JBx/8Be66607U1e1FOBzG5ZdfhTlzjsV1112F73znFvzzn39HW1sr\nvvhiG3bu3IFvf/u7mD17jvO6qqpi+fIfdHj+J598hPvvvweiKGDKlGm49trrU95n/5zRo8di1aqn\n0NjYiOnTD8eTT/4v2tvbcd11N+Ldd9/Gyy//HbquY/bsObj11u+ipaUFd955G9ra2lBSUoI77vgf\nXH75Rfj9759AMBjEhg3v4cknV2LFih93+5wVZPCOWsEb9rB5DtuyEhH1Z0//Ywv+89HeXn3NmROq\ncd5JY9N+/4ILFmL16qdx2WVX4pFHHoaqxvGLX/wWDQ37MWvWUTj11DOwc+cO3H77EsyZc6znuXv3\n7sF99z2E9ev/jT//eZUneLe0NKd8/oMP3oebb16KsWPH4Uc/ugO7d+9KeV86W7duwRNPrIbP58O7\n776NX/zitxBFEeeddxauvfabeOKJxzFr1myce+4CPPXUSrzzzls47rgT8dpr/8L8+afgtddewbx5\nX+nROS3I4G33NTc08+0z8yYi6j8mTpwMAAiFSvHhh5uxZs1qCIKI5uamDo+dOvUwAObyZHcXz86e\n/8UX2zB27DgAwO2335n2vnTGjh0Hn89crRQIBHDddVdBkiQ0NjaisbERn3zyEa644hoAwPnnXwQA\nqKmpxW9/+0vMn38K3n33bXzjG1cf+IlxKczgrSU2JQFYsEZEZDvvpLGdZsm5oChmD46//vUlNDc3\n4+c//y2am5txxRULOzxWkhLNu5KbgaV7fqpNsFLdJwiJxE5V1Q7Ht3v3Ljz11Er87ncrEQwGsXDh\nedZrSTAM3fNaY8eOw759+/Dhh5sxatQY+P3+zk9CBgUZtSL2piR25s2CNSKiPiOKIjRN63B/Y2Mj\nhgypgSiKeOWVfyAejx/Q66Z7/siRo7B5s7ni6a677sR///t5yvuKi4uxb5/ZbGzjxvdTvn5FRQWC\nwSA+/vgj7N69G/F4HBMnTsLbb/8HAPDcc6vw4ot/AQCcdNI8PPDAPZg375QDeh+pFGTwthlx88qH\nmTcRUd8ZMWIUPv74Izz00P2e+0844ST8+9+v4vrrr0FRURGqq6vx6KO/6fLrpnv+9dffhJ/97P/D\nNdd8A6FQKUaOHJXyvjPPPAf3338vbr75egwcWNXh9ceNG4+ioiCuueZy/P3v/4ezzjoHP/zhD3Hu\nuRdg06YNuO66q/Dvf7+G448/EQBw8snzsHfvXhx++MyenTAAgpGq6Xg/1Jvt5uJaHNf89hnojdWA\nIeKWC6ZjwojOe6tTamyn2HM8hz3Hc9g7eB57rrNz+Pzza7B79y584xvfPKDXS6Ug57wVSYHeMNi5\nzcybiIiy6Z57/gdffrkTd911X6+8XkEG72SsNiciomy69dbbevX1CjLl1HXvTAEL1oiIKJ8UZPCO\nxr1VjRw2JyKifFKQUSvWIXgz8yYiovxRkME7OfOW2B6ViIjySEFGrWjc2/mGmTcRUd96+eW/H/Bz\n3nvvHTQ07M/C0fR/hRm8Y0mZN+e8iYj6zIFuCWp7/vk1BRu8C3KpWMdhc2beRER9xb0l6PnnX4gV\nK36IlpYWaJqGG264GWPHjsP//u/v8cor/4Qoipgz51hMnDgJr776Mj7//DP8z//ci8GDzd4dfbEN\n6OWXX+VsAxqLReD3F2VlG1A3Bm+w2pyIyLZ6y1/w7t6Nvfqa06sPxTljz0j7ffeWoL///W9x5JFH\n4//9v6/i888/w09+ch8efPAXePLJ/8Vzz70ESZLw3HOrMHPmURg7djy+851bnMAN9M02oOeff6Gz\nDejixVfiZz/7VVa2AXVj8AabtBAR9RcbN25AY2MD1q59AQAQjZobSZ1wwsm44YbFmDfvFMyfn35j\nj77YBrS5uTkn24C6FWTwrgz54ZNF6IYBVTMgCgzeREQAcM7YMzrNkrNNUWTceOPNmDJlquf+m276\nHrZt+y/+8Y+/4lvf+iZ+/es/pHz+wbwNqOfYe+2V8sghwyvw1IrT8fBNJ+DXN5/Q14dDRFTQ3FuC\nTpo0Bf/618sAgM8//wxPPvm/aG1txaOP/gYjRozEZZddiVCoDO3tbSm3Ej2YtwH1nLNefbU8Iksi\nBEHgfDcRUR9zbwn69a+fj507t2Px4itwzz3/g8MOm4GSkhI0NjbgyisX4dvfvhqTJ09BaWkZDjts\nBm677VZ89tlW57X6YhvQ+++/x9kGdOHChVnbBtStILcEBbj1XW/heew5nsOe4znsHTyPPZd8Druz\nDWjy66VSkHPeRERE2dbb24C6MXgTERFlQW9vA+rGCV8iIqI8w+BNRESUZxi8iYiI8gyDNxERUZ5h\n8CYiIsozDN5ERER5hsGbiIgozzB4ExER5Zm8aY9KREREJmbeREREeYbBm4iIKM8weBMREeUZBm8i\nIqI8w+BNRESUZxi8iYiI8kxB7ue9YsUKvP/++xAEAUuXLsXUqVP7+pD6tU8++QSLFy/GpZdeiosv\nvhi7du3CLbfcAk3TUFVVhR//+Mfw+XxYs2YN/vCHP0AURZx33nk499xz+/rQ+417770Xb7/9NlRV\nxTe/+U0ceuihPIcHIBwOY8mSJdi3bx+i0SgWL16MCRMm8Bx2UyQSwRlnnIHFixdj9uzZPI8H4I03\n3sD111+PcePGAQDGjx+PK664Ivfn0Cgwb7zxhnHVVVcZhmEYW7ZsMc4777w+PqL+ra2tzbj44ouN\n2267zXj88ccNwzCMJUuWGC+88IJhGIZx//33GytXrjTa2tqM+fPnG83NzUY4HDZOP/10o6GhoS8P\nvd9Yt26dccUVVxiGYRj79+83jj/+eJ7DA/T8888bv/71rw3DMIwdO3YY8+fP5znsgQceeMA455xz\njFWrVvE8HqD169cb3/rWtzz39cU5LLhh83Xr1mHu3LkAgDFjxqCpqQmtra19fFT9l8/nw29+8xtU\nV1c7973xxhs4+eSTAQAnnngi1q1bh/fffx+HHnooQqEQAoEAZsyYgXfeeaevDrtfmTlzJn7yk58A\nAEpLSxEOh3kOD9Bpp52GK6+8EgCwa9cuDBo0iOewm7Zu3YotW7bghBNOAMD/z72hL85hwQXv+vp6\nVFRUOLcrKytRV1fXh0fUv8myjEAg4LkvHA7D5/MBAAYMGIC6ujrU19ejsrLSeQzPa4IkSQgGgwCA\nZ599FscddxzPYTctWLAAN910E5YuXcpz2E333HMPlixZ4tzmeTxwW7ZswdVXX40LLrgAr7/+ep+c\nw4Kc83Yz2B22R9KdP57Xjv72t7/h2Wefxe9+9zvMnz/fuZ/nsOuefPJJfPjhh7j55ps954fnsGue\ne+45HHbYYRg2bFjK7/M8ZjZy5Ehcd911OPXUU7F9+3YsWrQImqY538/VOSy44F1dXY36+nrn9t69\ne1FVVdWHR5R/gsEgIpEIAoEA9uzZg+rq6pTn9bDDDuvDo+xfXn31VfzqV7/Cb3/7W4RCIZ7DA7Rp\n0yYMGDAAQ4YMwcSJE6FpGoqLi3kOD9DLL7+M7du34+WXX8bu3bvh8/n4t3iABg0ahNNOOw0AMHz4\ncAwcOBAbN27M+TksuGHzOXPmYO3atQCAzZs3o7q6GiUlJX18VPnl6KOPds7h//3f/+HYY4/FtGnT\nsHHjRjQ3N6OtrQ3vvPMOjjjiiD4+0v6hpaUF9957Lx5++GGUl5cD4Dk8UG+99RZ+97vfATCnvtrb\n23kOu+HBBx/EqlWr8PTTT+Pcc8/F4sWLeR4P0Jo1a/DII48AAOrq6rBv3z6cc845OT+HBbmr2H33\n3Ye33noLgiBg2bJlmDBhQl8fUr+1adMm3HPPPdi5cydkWcagQYNw3333YcmSJYhGo6ipqcFdd90F\nRVHw0ksv4ZFHHoEgCLj44otx5pln9vXh9wtPPfUUfvrTn2LUqFHOfXfffTduu+02nsMuikQi+P73\nv49du3YhEonguuuuw5QpU3DrrbfyHHbTT3/6U9TW1uKYY47heTwAra2tuOmmm9Dc3Ix4PI7rrrsO\nEydOzPk5LMjgTURElM8KbticiIgo3zF4ExER5RkGbyIiojzD4E1ERJRnGLyJiIjyTME1aSHKN/fe\ney82btyIaDSKDz74ANOnTwcAfO1rX8NXv/rVLr3Gr3/9a4wfP97pZ53KwoUL8fvf/x6SJPXGYXvs\n2bMHn332GWbPnt3rr01UiLhUjChP7NixAxdeeCH+9a9/9fWhHLA1a9Zg69atuPHGG/v6UIgOCsy8\nifLYT3/6U+zYsQNffvklbr31VkQiEdx3333w+XyIRCJYtmwZJk+ejCVLluDwww/H7Nmzcc011+CY\nY47Bhg0b0NbWhocffhiDBg3CIYccgs2bN+OXv/wlGhsbsXv3bmzbtg1HHnkkbr/9dkSjUdx6663Y\nuXMnBg8eDEmSMGfOHM8exW1tbfjud7+L5uZmqKqKE088EWeccQYefPBBGIaB8vJyXHTRRbjzzjux\nbds2tLW14YwzzsDll1+O1atX469//SsEQcCePXswevRorFixAoqi9OEZJuqfOOdNlOd27NiBxx57\nDFOmTEFjYyN+8IMf4LHHHsOiRYvw8MMPd3j81q1bcc4552DlypWYOHEiXnzxxQ6P+eCDD/DQQw/h\n2WefxerVq9HU1IQ1a9ZAVVU888wzuOOOO/D66693eN6///1vqKqKP/7xj3jyyScRDAZRW1uLs88+\nG2eeeSYuu+wyPPbYY6iursbjjz+OZ555Bs8//zw++ugjAMDGjRv///bu2CW1MIzj+NcONQQRQi3W\nYnBsjDoSBFKNOVaEo0M4REO4HGyrKQin5ob+gDBaoiVyECEipakhWkKkQKFoiERPd5DOzYxLlysX\njvw+4+F5X97tx/PyHh7S6TSHh4eUy2VP3jKI/A/qvEU8bmJiAp/PB8DQ0BC7u7u8vb3x8vLC4OBg\nW73f78c0TQACgQBPT09tNZZlYRgGhmHg9/t5fn7m5uaG6elpAIaHh7Esq23d1NQUe3t7bGxsMDc3\nx8rKCj09rT3CxcUFDw8PXF5eAlCr1bi/v3fXf4xPnZyc5O7uzp2TLCK/KbxFPO7ztbJt22xvbzMz\nM8P5+bk7zOOzrw/Svnv28l2N4zgtQfw1lKE5y/j4+JhiscjZ2RnLy8scHR211PT19bG+vs7CwkLL\n90wmg+M4fzyXiDTp2lyki1QqFUzTpNFocHp6Sq1W69jeY2NjFItFAKrVKldXV201uVyObDaLZVnY\ntk1/fz/VahWfz0e9XgeaXf3HVb3jOOzs7Ljd//X1Na+vr7y/v1MoFBgfH+/Y+UW6iTpvkS6SSCSI\nx+MEAgFWV1exbZuDg4OO7L20tEQ2myUWizE6Oko4HG7r0IPBIKlUiv39fQzDIBKJMDIyQjgcJplM\n0tvby9raGre3t8RiMRqNBvPz8+6o1FAoxObmJqVSCdM0iUQiHTm7SLfRr2Ii8iOPj48UCgWi0SiO\n47C4uMjW1pb73/m/ymQy5PN50ul0R/YT6WbqvEXkRwYGBjg5OXHnE8/OznYsuEXk76jzFhER8Rg9\nWBMREfEYhbeIiIjHKLxFREQ8RuEtIiLiMQpvERERj1F4i4iIeMwvRph4T/csGFUAAAAASUVORK5C\nYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "metadata": { + "id": "HNqUFL4deCsL", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# 4. Case study: building an RNN\n" + ] + }, + { + "metadata": { + "id": "YkC1k4HEQ7rw", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "In this exercise we build and train a model similar to the RNNColorbot model that was used in the main Eager notebook. The model is adapted for converting and training in graph mode." + ] + }, + { + "metadata": { + "id": "7nkPDl5CTCNb", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "To get started, we load the colorbot dataset. The code is identical to that used in the other exercise and its details are unimportant." + ] + }, + { + "metadata": { + "id": "A0uREmVXCQEw", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def parse(line):\n", + " \"\"\"Parses a line from the colors dataset.\n", + " \n", + " Args:\n", + " line: A comma-separated string containing four items:\n", + " color_name, red, green, and blue, representing the name and\n", + " respectively the RGB value of the color, as an integer\n", + " between 0 and 255.\n", + "\n", + " Returns:\n", + " A tuple of three tensors (rgb, chars, length), of shapes: (batch_size, 3),\n", + " (batch_size, max_sequence_length, 256) and respectively (batch_size).\n", + " \"\"\"\n", + " items = tf.string_split([line], \",\").values\n", + " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", + " color_name = items[0]\n", + " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", + " length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n", + " return rgb, chars, length\n", + "\n", + "\n", + "def maybe_download(filename, work_directory, source_url):\n", + " \"\"\"Downloads the data from source url.\"\"\"\n", + " if not tf.gfile.Exists(work_directory):\n", + " tf.gfile.MakeDirs(work_directory)\n", + " filepath = os.path.join(work_directory, filename)\n", + " if not tf.gfile.Exists(filepath):\n", + " temp_file_name, _ = six.moves.urllib.request.urlretrieve(source_url)\n", + " tf.gfile.Copy(temp_file_name, filepath)\n", + " with tf.gfile.GFile(filepath) as f:\n", + " size = f.size()\n", + " print('Successfully downloaded', filename, size, 'bytes.')\n", + " return filepath\n", + "\n", + "\n", + "def load_dataset(data_dir, url, batch_size, training=True):\n", + " \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n", + " path = maybe_download(os.path.basename(url), data_dir, url)\n", + " dataset = tf.data.TextLineDataset(path)\n", + " dataset = dataset.skip(1)\n", + " dataset = dataset.map(parse)\n", + " dataset = dataset.cache()\n", + " dataset = dataset.repeat()\n", + " if training:\n", + " dataset = dataset.shuffle(buffer_size=3000)\n", + " dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None, None], []))\n", + " return dataset\n", + "\n", + "\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "data_dir = \"tmp/rnn/data\"" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "waZ89t3DTUla", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Next, we set up the RNNColobot model, which is very similar to the one we used in the main exercise.\n", + "\n", + "Autograph doesn't fully support classes yet (but it will soon!), so we'll write the model using simple functions." + ] + }, + { + "metadata": { + "id": "9v8AJouiC44V", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def model_components():\n", + " lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", + " lower_cell.build(tf.TensorShape((None, 256)))\n", + " upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " upper_cell.build(tf.TensorShape((None, 256)))\n", + " relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", + " relu_layer.build(tf.TensorShape((None, 128)))\n", + " return lower_cell, upper_cell, relu_layer\n", + "\n", + "\n", + "def rnn_layer(chars, cell, batch_size, training):\n", + " \"\"\"A simple RNN layer.\n", + " \n", + " Args:\n", + " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", + " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", + " \"\"\"\n", + " hidden_outputs = []\n", + " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " state, output = cell.zero_state(batch_size, tf.float32)\n", + " n = tf.shape(chars)[0]\n", + " i = 0\n", + " while i < n:\n", + " ch = chars[i]\n", + " cell_output, (state, output) = cell.call(ch, (state, output))\n", + " hidden_outputs.append(cell_output)\n", + " i += 1\n", + " hidden_outputs = hidden_outputs.stack()\n", + " if training:\n", + " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", + " return hidden_outputs\n", + "\n", + "\n", + "def model(inputs, lower_cell, upper_cell, relu_layer, batch_size, training):\n", + " \"\"\"RNNColorbot model.\n", + " \n", + " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", + " followed by a fully connected layer with ReLU activation.\n", + " \n", + " Args:\n", + " inputs: A tuple (chars, length)\n", + " lower_cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " upper_cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " relu_layer: An object of type tf.layers.Dense\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + " \n", + " Returns:\n", + " A Tensor of shape (batch_size, 3) - the model predictions.\n", + " \"\"\"\n", + " (chars, length) = inputs\n", + " chars_time_major = tf.transpose(chars, [1, 0, 2])\n", + " chars_time_major.set_shape((None, batch_size, 256))\n", + "\n", + " hidden_outputs = rnn_layer(chars_time_major, lower_cell, batch_size, training)\n", + " final_outputs = rnn_layer(hidden_outputs, upper_cell, batch_size, training)\n", + "\n", + " # Grab just the end-of-sequence from each output.\n", + " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " sequence_ends = tf.gather_nd(final_outputs, indices)\n", + " return relu_layer(sequence_ends)\n", + "\n", + "def loss_fn(labels, predictions):\n", + " return tf.reduce_mean((predictions - labels) ** 2)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "JjK4gXFvFsf4", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "The train and test functions are also similar to the ones used in the Eager notebook. Since the network requires a fixed batch size, we'll train in a single shot, rather than by epoch." + ] + }, + { + "metadata": { + "id": "ZWQMExk0S6X6", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", + " iterator = train_data.make_one_shot_iterator()\n", + " step = 0\n", + " while step < num_steps:\n", + " labels, chars, sequence_length = iterator.get_next()\n", + " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=True)\n", + " loss = loss_fn(labels, predictions)\n", + " optimizer.minimize(loss)\n", + " if step % (num_steps // 10) == 0:\n", + " print('Step', step, 'train loss', loss)\n", + " step += 1\n", + " return step\n", + "\n", + "\n", + "def test(eval_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):\n", + " total_loss = 0.0\n", + " iterator = eval_data.make_one_shot_iterator()\n", + " step = 0\n", + " while step < num_steps:\n", + " labels, chars, sequence_length = iterator.get_next()\n", + " predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=False)\n", + " total_loss += loss_fn(labels, predictions)\n", + " step += 1\n", + " print('Test loss', total_loss)\n", + " return total_loss\n", + "\n", + "\n", + "def train_model(train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps):\n", + " optimizer = tf.train.AdamOptimizer(learning_rate=0.01)\n", + "\n", + " train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps=tf.constant(train_steps))\n", + " test(eval_data, lower_cell, upper_cell, relu_layer, 50, num_steps=tf.constant(2))\n", + "\n", + " print('Colorbot is ready to generate colors!\\n\\n')\n", + " \n", + " # In graph mode, every op needs to be a dependent of another op.\n", + " # Here, we create a no_op that will drive the execution of all other code in\n", + " # this function. Autograph will add the necessary control dependencies.\n", + " return tf.no_op()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "iopcs5hXG2od", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, we add code to run inference on a single input, which we'll read from the input.\n", + "\n", + "Note the `do_not_convert` annotation that lets us disable conversion for certain functions and run them as a `py_func` instead, so you can still call them from compiled code." + ] + }, + { + "metadata": { + "id": "DyU0wnnAFEYj", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)\n", + "def draw_prediction(color_name, pred):\n", + " pred = pred * 255\n", + " pred = pred.astype(np.uint8)\n", + " plt.axis('off')\n", + " plt.imshow(pred)\n", + " plt.title(color_name)\n", + " plt.show()\n", + "\n", + "\n", + "def inference(color_name, lower_cell, upper_cell, relu_layer):\n", + " _, chars, sequence_length = parse(color_name)\n", + " chars = tf.expand_dims(chars, 0)\n", + " sequence_length = tf.expand_dims(sequence_length, 0)\n", + " pred = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, 1, training=False)\n", + " pred = tf.minimum(pred, 1.0)\n", + " pred = tf.expand_dims(pred, 0)\n", + " draw_prediction(color_name, pred)\n", + " # Create an op that will drive the entire function.\n", + " return tf.no_op()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "Nt0Kv5OCHip0", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "Finally, we put everything together.\n", + "\n", + "Note that the entire training and testing code is all compiled into a single op (`tf_train_model`) that you only execute once! We also still use a `sess.run` loop for the inference part, because that requires keyboard input." + ] + }, + { + "metadata": { + "id": "-GmWa0GtYWdh", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {}, + {} + ], + "base_uri": "https://localhost:8080/", + "height": 668 + }, + "outputId": "61f4af1d-c81e-44db-9079-1a7b8ed8ce58", + "executionInfo": { + "status": "ok", + "timestamp": 1522345877153, + "user_tz": 240, + "elapsed": 75500, + "user": { + "displayName": "Dan Moldovan", + "photoUrl": "//lh5.googleusercontent.com/-Rneh8xjecyk/AAAAAAAAAAI/AAAAAAAACB4/c5vwsJpbktY/s50-c-k-no/photo.jpg", + "userId": "112023154726779574577" + } + } + }, + "cell_type": "code", + "source": [ + "def run_input_loop(sess, inference_ops, color_name_placeholder):\n", + " \"\"\"Helper function that reads from input and calls the inference ops in a loop.\"\"\"\n", + "\n", + " tb = widgets.TabBar([\"RNN Colorbot\"])\n", + " while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " sess.run(inference_ops, {color_name_placeholder: color_name})\n", + " plt.show()\n", + "\n", + "with tf.Graph().as_default():\n", + " # Read the data.\n", + " batch_size = 64\n", + " train_data = load_dataset(data_dir, train_url, batch_size)\n", + " eval_data = load_dataset(data_dir, test_url, 50, training=False)\n", + " \n", + " # Create the model components.\n", + " lower_cell, upper_cell, relu_layer = model_components()\n", + " # Create the helper placeholder for inference.\n", + " color_name_placeholder = tf.placeholder(tf.string, shape=())\n", + " \n", + " # Compile the train / test code.\n", + " tf_train_model = autograph.to_graph(train_model)\n", + " train_model_ops = tf_train_model(\n", + " train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)\n", + " \n", + " # Compile the inference code.\n", + " tf_inference = autograph.to_graph(inference)\n", + " inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)\n", + " \n", + " with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " \n", + " # Run training and testing.\n", + " sess.run(train_model_ops)\n", + " \n", + " # Run the inference loop.\n", + " run_input_loop(sess, inference_ops, color_name_placeholder)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "('Successfully downloaded', 'train.csv', 28010L, 'bytes.')\n", + "('Successfully downloaded', 'test.csv', 2414L, 'bytes.')\n", + "Step 0 train loss 0.37890616\n", + "Step 10 train loss 0.18515904\n", + "Step 20 train loss 0.0892782\n", + "Step 30 train loss 0.07883155\n", + "Step 40 train loss 0.08585831\n", + "Step 50 train loss 0.09302989\n", + "Step 60 train loss 0.089012615\n", + "Step 70 train loss 0.07275697\n", + "Step 80 train loss 0.06644974\n", + "Step 90 train loss 0.0854013\n", + "Test loss 0.13216865Colorbot is ready to generate colors!\n", + "\n", + "\n", + "\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "
" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b102d936-3379-11e8-ac70-0242ac110002\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"borderColor\": [\"#a7a7a7\"], \"tabNames\": [\"RNN Colorbot\"], \"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"elementId\": \"id1\"});\n", + "//# sourceURL=js_e223a56194" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b103532a-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_b8c6a821fb" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b105b28c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_44805e254b" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b106197a-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_a63d3c6c47" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b1069f44-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b106197a-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7e203b8bce" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"b1070f38-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_d53293d4a7" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6d90d5c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"b105b28c-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_3000dc2c05" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6da872c-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_4136f669a3" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6dac868-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_2f70dd9aee" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6db07d8-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6dac868-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7226726048" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c6dcc6fe-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_72e7709865" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVQAAAFZCAYAAADHDNdrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAB9JJREFUeJzt3E1Lle0ax+HTF4jeEAyMBhE0DawI\nwsCH0AIlaGBWNJBo0CDoA0TQhmDXuKAGDioiCA2KlEAlnl05FD9Co8BeaGCQoBDa2jPZsXt4Bvu/\n0+o4Rmvd1zW4rsmP84bFamo0Go0C4H/WvNYHAPhVCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKDy\nUxgeHq5Dhw7V4OBgPXz4sHp7e+vWrVt15cqVOnnyZN2/f78ajUbdvn27+vr6qqenp65du1YrKytV\nVfXhw4e6cOFC9fX1VV9fX01PT1dV1dzcXHV3d9eDBw/q+PHj9ccff9TExMRaXpWfWOtaHwD+zuvX\nr+vOnTs1MTFRbW1tdf78+dW16enpGh8fr/b29hobG6upqal6/Phxbdy4sS5evFgjIyM1NDRUly5d\nqv3799fw8HC9efOmTp8+XVNTU1VV9enTp2pubq5nz57V5ORk3bhxo44dO7ZW1+UnZkJl3Zudna2D\nBw9WR0dHbdiwoQYHB1fX9u7dW+3t7VVV9fLlyxocHKytW7dWa2trnTp1qp4/f16Li4s1MzNT586d\nq6qqXbt21YEDB1an1OXl5Tpx4kRVVe3Zs6fevXv3Yy/IL8OEyrr3+fPnamtrW/2+ffv21c//+Xxh\nYaHu3r1bjx49qqqqlZWVam9vr4WFhWo0GnXmzJnVvYuLi9XV1VVVVS0tLbVp06aqqmpubq6vX7/+\nX+/Dr0tQWfe2bNlSi4uLq98/fvz43X0dHR3V29tbQ0ND3zxfXl6ulpaWevLkSW3evPmbtbm5ufyB\n+W155Wfd6+zsrJmZmZqfn68vX77U2NjYd/cdOXKkxsfHa2lpqaqqRkdH6+nTp9Xa2lqHDx+u0dHR\nqqpaWlqqy5cv1/v373/YHfg9CCrrXmdnZw0MDNTAwECdPXu2enp6vrvv6NGj1dPTUwMDA9Xf318v\nXryo7u7uqqq6evVqzc7OVn9/fw0MDNTOnTtrx44dP/Ia/Aaa/B8qP4NGo1FNTU1VVfXq1au6efPm\nX06qsFZMqKx78/Pz1dXVVW/fvq1Go1GTk5O1b9++tT4W/BcTKj+FkZGRunfvXjU1NdXu3bvr+vXr\ntW3btrU+FnxDUAFCvPIDhAgqQMi6+WH/kX8eXesjAPytf/3jz79cM6EChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCI\noAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIig\nAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAC\nhAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKE\nCCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQI\nKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgq\nQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpA\niKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkCIoAKECCpAiKAChAgqQIigAoQIKkBI\nU6PRaKz1IQB+BSZUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAEEEFCBFUgBBB\nBQgRVIAQQQUIEVSAEEEFCBFUgBBBBQgRVIAQQQUIEVSAkH8D1Aj8lNhhe7QAAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c70592aa-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c6da872c-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_25c3aaf79a" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c70842c0-3379-11e8-ac70-0242ac110002\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_984c56b816" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c708dec4-3379-11e8-ac70-0242ac110002\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_e0451a1217" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7092726-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c708dec4-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_7aa23d7385" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7099044-3379-11e8-ac70-0242ac110002\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_5722756ddb" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + }, + { + "output_type": "stream", + "text": [ + "Give me a color name (or press 'enter' to exit): \n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "window[\"c7baac12-3379-11e8-ac70-0242ac110002\"] = google.colab.output.setActiveOutputArea(window[\"c70842c0-3379-11e8-ac70-0242ac110002\"]);\n", + "//# sourceURL=js_cdd622e58f" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + } + } + ] + }, + { + "metadata": { + "id": "AHJ2c47U-A5W", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "# Where do we go next?\n", + "\n", + "Autograph is available in tensorflow.contrib, but it's still in its early stages. We're excited about the possibilities it brings — write your machine learning code in the flexible Eager style, but still enjoy all the benefits that come with running in graph mode. A beta version will be available soon -- stay tuned!" + ] + } + ] +} diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_colorbot_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_colorbot_estimator.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7f5e4d4ac124f3e9834a87193da110160926e77e --- /dev/null +++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_colorbot_estimator.ipynb @@ -0,0 +1,1421 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "LqNpENf-ec0X", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "!pip install -U tf-nightly" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Pa2qpEmoVOGe", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow.contrib import autograph\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import six\n", + "\n", + "from google.colab import widgets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HNqUFL4deCsL", + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Case study: building an RNN\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YkC1k4HEQ7rw", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "In this section, we show how you can use AutoGraph to build RNNColorbot, an RNN that takes as input names of colors and predicts their corresponding RGB tuples. The model will be trained by a [custom Estimator](https://www.tensorflow.org/get_started/custom_estimators)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7nkPDl5CTCNb", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "To get started, set up the dataset. The following cells defines methods that download and format the data needed for RNNColorbot; the details aren't important (read them in the privacy of your own home if you so wish), but make sure to run the cells before proceeding." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "A0uREmVXCQEw", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "def parse(line):\n", + " \"\"\"Parses a line from the colors dataset.\"\"\"\n", + " items = tf.string_split([line], \",\").values\n", + " rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n", + " color_name = items[0]\n", + " chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n", + " length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n", + " return rgb, chars, length\n", + "\n", + "def load_dataset(data_dir, url, batch_size, training=True):\n", + " \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n", + " path = tf.keras.utils.get_file(os.path.basename(url), url, cache_dir=data_dir)\n", + " dataset = tf.data.TextLineDataset(path)\n", + " dataset = dataset.skip(1)\n", + " dataset = dataset.map(parse)\n", + " dataset = dataset.cache()\n", + " dataset = dataset.repeat()\n", + " if training:\n", + " dataset = dataset.shuffle(buffer_size=3000)\n", + " dataset = dataset.padded_batch(\n", + " batch_size, padded_shapes=([None], [None, None], []))\n", + " return dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "waZ89t3DTUla", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "To show the use of control flow, we write the RNN loop by hand, rather than using a pre-built RNN model.\n", + "\n", + "Note how we write the model code in Eager style, with regular `if` and `while` statements. Then, we annotate the functions with `@autograph.convert` to have them automatically compiled to run in graph mode." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "9v8AJouiC44V", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "class RnnColorbot(object):\n", + " \"\"\"Holds the parameters of the colorbot model.\"\"\"\n", + "\n", + " def __init__(self):\n", + " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", + " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", + " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", + "\n", + " self.lower_cell.build(tf.TensorShape((None, 256)))\n", + " self.upper_cell.build(tf.TensorShape((None, 256)))\n", + " self.relu_layer.build(tf.TensorShape((None, 128)))\n", + "\n", + "\n", + "def rnn_layer(chars, cell, batch_size, training):\n", + " \"\"\"A simple RNN layer.\n", + " \n", + " Args:\n", + " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", + " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", + " \"\"\"\n", + " hidden_outputs = []\n", + " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " state, output = cell.zero_state(batch_size, tf.float32)\n", + " for ch in chars:\n", + " cell_output, (state, output) = cell.call(ch, (state, output))\n", + " hidden_outputs.append(cell_output)\n", + " hidden_outputs = hidden_outputs.stack()\n", + " if training:\n", + " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", + " return hidden_outputs\n", + "\n", + "\n", + "@autograph.convert(recursive=True)\n", + "def model(inputs, colorbot, batch_size, training):\n", + " \"\"\"RNNColorbot model.\n", + " \n", + " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", + " followed by a fully connected layer with ReLU activation.\n", + " \n", + " Args:\n", + " inputs: A tuple (chars, length)\n", + " colorbot: An object of type RnnColorbot\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + " \n", + " Returns:\n", + " A Tensor of shape (batch_size, 3) - the model predictions.\n", + " \"\"\"\n", + " (chars, length) = inputs\n", + " seq = tf.transpose(chars, [1, 0, 2])\n", + " seq.set_shape((None, batch_size, 256))\n", + "\n", + " seq = rnn_layer(seq, colorbot.lower_cell, batch_size, training)\n", + " seq = rnn_layer(seq, colorbot.upper_cell, batch_size, training)\n", + "\n", + " # Grab just the end-of-sequence from each output.\n", + " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " sequence_ends = tf.gather_nd(seq, indices)\n", + " return colorbot.relu_layer(sequence_ends)\n", + "\n", + "@autograph.convert()\n", + "def loss_fn(labels, predictions):\n", + " return tf.reduce_mean((predictions - labels) ** 2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JjK4gXFvFsf4", + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "We will now create the model function for the estimator.\n", + "\n", + "In the model function, we simply call the converted functions that we defined above - that's it!" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "-yso_Nx23Gy1", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "def model_fn(features, labels, mode, params):\n", + " \"\"\"Estimator model function.\"\"\"\n", + " chars = features['chars']\n", + " sequence_length = features['sequence_length']\n", + " inputs = (chars, sequence_length)\n", + "\n", + " # Create the model components.\n", + " # Simply calling the AutoGraph-ed functions and objects just works!\n", + " colorbot = RnnColorbot()\n", + " \n", + " batch_size = params['batch_size']\n", + "\n", + " if mode == tf.estimator.ModeKeys.TRAIN:\n", + " predictions = model(inputs, colorbot, batch_size, training=True)\n", + " loss = loss_fn(labels, predictions)\n", + "\n", + " learning_rate = params['learning_rate']\n", + " optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n", + " global_step = tf.train.get_global_step()\n", + " train_op = optimizer.minimize(loss, global_step=global_step)\n", + " return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)\n", + "\n", + " elif mode == tf.estimator.ModeKeys.EVAL:\n", + " predictions = model(inputs, colorbot, batch_size, training=False)\n", + " loss = loss_fn(labels, predictions)\n", + "\n", + " return tf.estimator.EstimatorSpec(mode, loss=loss)\n", + " \n", + " elif mode == tf.estimator.ModeKeys.PREDICT:\n", + " # For prediction, we expect single tensors.\n", + " predictions = model(inputs, colorbot, 1, training=False)\n", + "\n", + " predictions = tf.minimum(predictions, 1.0)\n", + " return tf.estimator.EstimatorSpec(mode, predictions=predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "HOQfoBnHC9CP", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "We'll create an input function that will feed our training and eval data." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "FJZlx7yG2MP0", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "def input_fn(data_dir, data_url, params, training=True):\n", + " \"\"\"An input function for training\"\"\"\n", + " batch_size = params['batch_size']\n", + " \n", + " # load_dataset defined above\n", + " dataset = load_dataset(data_dir, data_url, batch_size, training=training)\n", + "\n", + " # Package the pipeline end in a format suitable for the estimator.\n", + " labels, chars, sequence_length = dataset.make_one_shot_iterator().get_next()\n", + " features = {\n", + " 'chars': chars,\n", + " 'sequence_length': sequence_length\n", + " }\n", + "\n", + " return features, labels" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qsvv-lzbDqXd", + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "We now have everything in place to build our custom estimator and use it for training and eval!" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 35 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 10064, + "status": "ok", + "timestamp": 1523580419240, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "2pg1AfbxBJQq", + "outputId": "41894b16-3d3a-4e30-f6e4-5a9c837a2210", + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval loss at step 100: 0.0665446\n" + ] + } + ], + "source": [ + "params = {\n", + " 'batch_size': 64,\n", + " 'learning_rate': 0.01,\n", + "}\n", + "\n", + "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n", + "data_dir = \"tmp/rnn/data\"\n", + "\n", + "regressor = tf.estimator.Estimator(\n", + " model_fn=model_fn,\n", + " params=params)\n", + "\n", + "regressor.train(\n", + " input_fn=lambda: input_fn(data_dir, train_url, params),\n", + " steps=100)\n", + "eval_results = regressor.evaluate(\n", + " input_fn=lambda: input_fn(data_dir, test_url, params, training=False),\n", + " steps=2\n", + ")\n", + "\n", + "print('Eval loss at step %d: %s' % (eval_results['global_step'], eval_results['loss']))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zG1YAjB_cUnQ", + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "And here's the same estimator used for inference." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 343 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 31286, + "status": "ok", + "timestamp": 1523580450579, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "dxHex2tUN_10", + "outputId": "b3dc558d-b800-4e9b-e60e-3441124e80d8", + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f4112527e90\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f4112527f10\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f4112527f50\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"2c60f474-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"borderColor\": [\"#a7a7a7\"], \"contentBorder\": [\"0px\"], \"tabNames\": [\"RNN Colorbot\"], \"elementId\": \"id1\"});\n", + "//# sourceURL=js_a0db480422" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd1d0\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"2c60f475-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_d2a46ea291" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd0d0\u003e" + ] + }, + "metadata": { + "tags": [ + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"2c60f476-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_0a8262c6e9" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd390\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"2c60f477-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_e32f85ccd2" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd490\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"2c60f478-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"2c60f477-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_eaee748b21" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd550\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"2c60f479-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_2befe06587" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f4112527f10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b1a-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"2c60f476-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_8ec4aeeb25" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd690\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b1b-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_9f9f4574f1" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd350\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b1c-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_bcccd8f300" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd6d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b1d-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b1c-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_2c056cee72" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd490\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b1e-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_c853c3f58b" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd610\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b1f-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b1b-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_e5730ab00d" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2050\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b20-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_a897ef7e24" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2250\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b21-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_565fa3d154" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f4113124d90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b22-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b21-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_222e0dc6af" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f4113124c10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"354d7b23-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_831db7458f" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f4113124310\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fab4-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b20-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_adb576c6eb" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f990850\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fab5-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_9418f2d32f" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f990850\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fab6-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_3fad25f306" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f4112527ed0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fab7-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803fab6-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_45b9340e7b" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f990c90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fab8-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_bec9896d44" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f990a10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fab9-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803fab5-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_460b91ad4a" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3a10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803faba-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_7dedd0b037" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3890\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fabb-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_4b1c977dc7" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3bd0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fabc-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803fabb-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_d64fedfcf9" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3410\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3803fabd-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_3e8c929c3f" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3c50\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b986c-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803faba-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_9f9cf2b76f" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd590\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b986d-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_b402e6b587" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3d90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b986e-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_9b7d66db72" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3b10\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b986f-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b986e-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_11ec213a3f" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3950\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9870-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_9c055e4bc0" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41b21d3850\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMRJREFUeJzt3F+IlfW+x/Gvp3FECyIqU4PCO7EgZnQtnUJ0JJGoTDoY\n/dGrMBJhosggIgK7KwwiMdxRF11F/0AJvIisLBqcguxCjEAkmNQGcRvVwIzm71zsc4Yje7P3x9h7\nz97u1+tqrYdnPeu7nos3v2f9m9FaawUQ+K/pHgD49yEYQEwwgJhgADHBAGKCAcQEg2nx9NNPV7fb\nrfvuu69GRkZq5cqV0z0SAcG4xK1evbqGh4ene4wLfPXVVzU8PFyfffZZvf3221VVNWPGjGmeioRg\n8E/122+/1Q8//FDXX399zZo1a7rH4SIJxiXsqaeeqhMnTtSWLVuqv7+/Xn/99frmm2/q/vvvr06n\nU+vXr6+RkZGp/Tdt2lQvv/xyPfDAA9Xf318PP/xwnTlzpqqqJicna9u2bbVs2bLqdDq1YcOGOn36\ndFVVjY2N1ZYtW2rZsmW1du3aeuedd6aOuXPnzhoaGqpt27bV0qVL67333qtnn322Dh06VP39/bVz\n584/m/vo0aO1adOm6nQ6dffdd9f+/furqmp0dLQ6nc7Ufs8880zdeuutU/e3bdtWb7755t/3JHKh\nxiVtcHCwDQ8Pt9ZaO3nyZOt2u+3AgQOttda++OKL1u122+nTp1trrW3cuLGtWbOmff/9921iYqJt\n3Lix7dixo7XW2ltvvdUeffTRNjEx0c6fP98OHz7cfvnll9Zaaw899FDbvn17m5ycbEeOHGnLly+f\nes5XXnml3XTTTe2jjz5qrbU2MTHR3n///fbggw9OzXjw4MG2cuXK1lprZ8+ebWvWrGm7d+9uZ8+e\nbcPDw62vr68dO3Zs6vUcPny4tdba2rVr2+23396OHj3aWmtt1apV7ciRI/+oU0lrzQrjP0D7358L\n7d27t1atWlUrVqyoqqqBgYG6+eab69NPP53a9957760bbrihent764477qgjR45UVVVPT0+dOXOm\njh07VjNmzKjFixfX5ZdfXidPnqyvv/66nnzyyZo5c2YtWrSoNmzYUHv27Jk6Zl9fX61evbqqqnp7\ne//qrIcOHarx8fF65JFHqqenp5YvX16Dg4P1wQcfVFXV0qVLa2RkpE6dOlVVVWvXrq0vv/yyRkdH\n69dff61Fixb9nc4af0nPdA/AP8/x48dr37599fHHH1fVn0Jy7ty5GhgYmNrnmmuumbo9e/bsGh8f\nr6qqe+65p06ePFlPPPFE/fzzz7Vu3bp6/PHHa2xsrK688sqaPXv21OMWLFhQhw8fnro/b968eMax\nsbGaP3/+BdsWLFhQY2NjVVXV6XRq//79dd1111W3261ut1t79uyp3t7eWrJkyUWcDX4PwbjE/f9P\nH+bPn1/r16+v7du3X/Rxenp6auvWrbV169Y6fvx4bd68uRYuXFi33XZb/fTTTzU+Pl5z5sypqqoT\nJ07U3Llz/+IMf8vcuXPrxIkTF2w7fvx4LVy4sKqqut1uvfjiizV//vzqdDrV399fzz33XPX29la3\n273o18XFcUlyibv22mtrdHS0qqrWrVtX+/fvr88//7zOnz9fExMTNTIyUj/++OPfPM7Bgwfru+++\nq/Pnz9ecOXOqp6enLrvsspo3b1719fXVSy+9VJOTk/Xtt9/Wu+++W+vWrftd895yyy01Z86ceu21\n1+rcuXN18ODB+uSTT+rOO++sqqobb7yxZs2aVXv37q1Op1NXXHFFXX311fXhhx9e8IYo/xiCcYnb\nvHlz7dq1q7rdbu3bt6927dpVu3fvroGBgRocHKw33nhj6j2Ov7YSOHXqVA0NDdWSJUvqrrvuqmXL\nlk1FYceOHTU6OlorVqyooaGheuyxxy64zLkYM2fOrFdffbUOHDhQy5cvr+eff75eeOGFqRVG1Z9W\nGVddddXUpc7/hWLx4sW/6znJzWjNH+gAGSsMICYYQEwwgJhgALF/2e9h/PEP/z3dI8B/tKseee/P\ntllhADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwg\nJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICY\nYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKC\nAcQEA4gJBhATDCA2o7XWpnsI4N+DFQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE/gfh60wGjfc7LQAAAABJRU5ErkJg\ngg==\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f4113124310\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9871-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b986d-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_ba6a061307" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd890\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9872-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_83e3496927" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd590\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9873-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_f437bab20d" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a22d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9874-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b9873-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_93aa63450e" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2b90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9875-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_aca189bea5" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd4d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\u003cdiv class=id_100313201 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + ], + "text/plain": [ + "\u003cIPython.core.display.HTML at 0x7f410f990a90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9876-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_100313201 span\");\n", + "//# sourceURL=js_5df1fe383e" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f410f8fd490\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3b9b9877-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"3b9b9876-3eb4-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_c62c7174ad" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2390\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3ed76584-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_100313201 input\");\n", + "//# sourceURL=js_2e2201ddc4" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2810\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3ed76585-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"3ed76584-3eb4-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_288e5283d6" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a26d0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3ed76586-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_100313201 span\");\n", + "//# sourceURL=js_2f31d19cde" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2fd0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3ed76587-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"3ed76586-3eb4-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_2fbbcda050" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f4112527e90\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1", + "user_output" + ] + }, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "window[\"3ed76588-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b9872-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_f94d975cf3" + ], + "text/plain": [ + "\u003cIPython.core.display.Javascript at 0x7f41127a2fd0\u003e" + ] + }, + "metadata": { + "tags": [ + "id1_content_0", + "outputarea_id1" + ] + }, + "output_type": "display_data" + } + ], + "source": [ + "def predict_input_fn(color_name):\n", + " \"\"\"An input function for prediction.\"\"\"\n", + " _, chars, sequence_length = parse(color_name)\n", + " \n", + " # We create a batch of a single element.\n", + " features = {\n", + " 'chars': tf.expand_dims(chars, 0),\n", + " 'sequence_length': tf.expand_dims(sequence_length, 0)\n", + " }\n", + " return features, None\n", + "\n", + "\n", + "def draw_prediction(color_name, pred):\n", + " pred = pred * 255\n", + " pred = pred.astype(np.uint8)\n", + " plt.axis('off')\n", + " plt.imshow(pred)\n", + " plt.title(color_name)\n", + " plt.show()\n", + "\n", + "\n", + "def predict_with_estimator(color_name, regressor):\n", + " predictions = regressor.predict(\n", + " input_fn=lambda:predict_input_fn(color_name))\n", + " pred = next(predictions)\n", + " predictions.close()\n", + " pred = np.minimum(pred, 1.0)\n", + " pred = np.expand_dims(np.expand_dims(pred, 0), 0)\n", + "\n", + " draw_prediction(color_name, pred)\n", + "\n", + "tb = widgets.TabBar([\"RNN Colorbot\"])\n", + "while True:\n", + " with tb.output_to(0):\n", + " try:\n", + " color_name = six.moves.input(\"Give me a color name (or press 'enter' to exit): \")\n", + " except (EOFError, KeyboardInterrupt):\n", + " break\n", + " if not color_name:\n", + " break\n", + " with tb.output_to(0):\n", + " tb.clear_tab()\n", + " predict_with_estimator(color_name, regressor)\n", + " " + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "RNN Colorbot using Estimators", + "provenance": [ + { + "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", + "timestamp": 1523579810961 + }, + { + "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG", + "timestamp": 1523016192637 + }, + { + "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K", + "timestamp": 1522238054357 + }, + { + "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ", + "timestamp": 1521743157199 + }, + { + "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-", + "timestamp": 1520522344607 + } + ], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD similarity index 74% rename from tensorflow/contrib/py2tf/impl/BUILD rename to tensorflow/contrib/autograph/impl/BUILD index 90ffabbc9bf4524ec2ebf54b6dd847bd8768a486..54424e26472b8466b8fe68ea848b5463c10224c9 100644 --- a/tensorflow/contrib/py2tf/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -25,10 +25,11 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "//tensorflow/contrib/py2tf/converters", - "//tensorflow/contrib/py2tf/pyct", - "//tensorflow/contrib/py2tf/pyct/static_analysis", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", "@six_archive//:six", ], @@ -38,10 +39,12 @@ py_test( name = "api_test", srcs = ["api_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", - "//tensorflow/contrib/py2tf/utils", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", + "//third_party/py/numpy", ], ) @@ -49,6 +52,7 @@ py_test( name = "conversion_test", srcs = ["conversion_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":impl", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py new file mode 100644 index 0000000000000000000000000000000000000000..3c3130c77025c45ca219daf4bb66082f4e8a7f82 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/api.py @@ -0,0 +1,296 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Public API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +from enum import Enum + +# pylint:disable=g-bad-import-order +import gast +import six +# pylint:enable=g-bad-import-order + +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_inspect + +# TODO(mdan): Properly document the type hints. +# TODO(mdan): Reduce the type hint information to (module, type). +# (currently we require (module + class name, type)) + + +def convert(recursive=False, verbose=False, arg_types=None): + """Decorator that compiles a function to graph mode. + + The decorator is dynamic - invoking compilation whenever the decorated + function is called. This means the parameter values are known at compilation. + + Args: + recursive: Whether to recursively convert any functions that the decorator + function may call. + verbose: Whether to output the compiled code in the logs. + arg_types: See to_graph. + + Returns: + A decorator that compiles the given function to graph mode. + + Raises: + ValueError: If any of the arguments are illegal. + """ + if arg_types is None: + arg_types = {} + + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def wrapper(*args, **kwargs): + return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +class RunMode(Enum): + GRAPH = 1 + PY_FUNC = 2 + + +def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): + """Decorator that suppresses compilation of a function. + + Args: + run_as: RunMode value. Whether to run the function as-is, or wrap it into + a py_func. + return_dtypes: See autograph.utils.py_func.wrap_py_func. Setting to None or + empty list or tuple will create a dummy return value that can be used + to set control dependencies. + + Returns: + A decorator that wraps the original function. + """ + def decorator(f): + """Decorator implementation.""" + + @wraps(f) + def graph_wrapper(*args, **kwargs): + return f(*args, **kwargs) + + @wraps(f) + def py_func_wrapper(*args, **kwargs): + if kwargs: + raise NotImplementedError( + 'RunMode.PY_FUNC does not yet support kwargs') + # TODO(mdan): Add support for kwargs. + return py_func.wrap_py_func( + f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes) + + if run_as == RunMode.GRAPH: + wrapper = graph_wrapper + elif run_as == RunMode.PY_FUNC: + wrapper = py_func_wrapper + else: + raise ValueError('unknown value for run_as: %s' % run_as) + + # Sometimes the decorator is just desugared, making it impossible to detect. + # This attribute makes detection easier. + setattr(wrapper, '__pyct_is_compile_decorator', True) + return wrapper + + return decorator + + +def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): + """Compiles a function call inline.""" + # TODO(mdan): This needs cleanup. + # In particular, we may want to avoid renaming functions altogether. + + if conversion.is_whitelisted_for_graph(f): + return f(*args, **kwargs) + + unknown_arg_value = object() # Sentinel for arguments of unknown value + + if inspect_utils.isbuiltin(f): + return builtins.dynamic_builtin(f, *args, **kwargs) + + if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): + # Regular functions + target_entity = f + arg_map_target = f + effective_args = args + f_class = inspect_utils.getmethodclass(f) + + if f_class is not None: + partial_types = (f_class,) + else: + partial_types = () + + elif tf_inspect.isclass(f): + # Constructors + target_entity = f + arg_map_target = f.__init__ + effective_args = args + partial_types = () + + elif hasattr(f, '__call__') and hasattr(f, '__class__'): + # Callable objects + target_entity = f.__call__ + arg_map_target = f.__call__ + effective_args = (f,) + args + partial_types = (f.__class__,) + + else: + NotImplementedError('unknown callable type "%s"' % type(f)) + + arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) + for name, arg in arg_values.items(): + if arg is unknown_arg_value: + continue + arg_class = arg.__class__ + # If arg_value_hints specifies any name, use that instead. + if name not in arg_types: + arg_types[name] = (arg_class.__name__, arg_class) + + # When called from within a decorator, this is the only indication that + # the function is a method - it appears that the decorator is applied + # before the method is bound. + if not partial_types: + if 'self' in arg_values: + if tf_inspect.isclass(arg_values['self'].__class__): + partial_types = (arg_values['self'].__class__,) + elif 'cls' in arg_values: + if tf_inspect.isclass(arg_values['cls']): + partial_types = (arg_values['cls'],) + + converted_f = to_graph( + target_entity, + recursive=recursive, + verbose=verbose, + arg_values=arg_values, + arg_types=arg_types, + partial_types=partial_types) + return converted_f(*effective_args, **kwargs) + + +def to_graph(e, + recursive=True, + verbose=False, + arg_values=None, + arg_types=None, + partial_types=None): + """Compile a Python entity into equivalent TensorFlow code. + + Currently supported entities: + * functions + * classes + + Classes are handled by converting all their methods into a new class. + + Args: + e: A Python entity. + recursive: Whether to recursively convert any functions that the decorator + function may call. + verbose: Whether to output the compiled code in the logs. + arg_values: A dict containing value hints for symbols like function + parameters. + arg_types: A dict containing type hints for symbols like function + parameters. + partial_types: A set of types (e.g. classes) that will not be converted + entirely. Calls to member functions for these types will be renamed + independently. + + Returns: + A function with a signature identical to `o`, but which when executed it + creates TF a graph that has the same functionality as the original entity. + """ + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, do_not_convert, converted_call), + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) + _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values, + arg_types) + + module = gast.Module([]) + for import_line in config.COMPILED_IMPORT_STATEMENTS: + module.body.extend(parser.parse_str(import_line).body) + for dep in conversion_map.dependency_cache.values(): + module.body.append(dep) + compiled_node, compiled_src = compiler.ast_to_object(module) + + # The compiled code should see everything the entry entity saw. + # TODO(mdan): This might not work well if the call tree spans modules? + for key, val in namespace.items(): + # Avoid overwriting entities that have been transformed. + if key not in compiled_node.__dict__: + compiled_node.__dict__[key] = val + compiled_fn = getattr(compiled_node, name) + + if verbose: + logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) + + return compiled_fn + + +def to_code(e, + recursive=True, + arg_values=None, + arg_types=None, + partial_types=None, + indentation=' '): + """Return the equivalent of an entity in TensorFlow code. + + See `to_graph` for more details. + + Args: + e: A Python entity. + recursive: See to_graph. + arg_values: See to_graph. + arg_types: See to_graph. + partial_types: See to_graph. + indentation: String, when to use for each level of indentation. + + Returns: + String. + """ + conversion_map = conversion.ConversionMap( + recursive=recursive, + nocompile_decorators=(convert, do_not_convert, converted_call), + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) + conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) + + imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) + code = '\n'.join( + compiler.ast_to_source(dep, indentation) + for dep in reversed(tuple( + six.itervalues(conversion_map.dependency_cache)))) + + return imports + '\n\n' + code diff --git a/tensorflow/contrib/py2tf/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py similarity index 55% rename from tensorflow/contrib/py2tf/impl/api_test.py rename to tensorflow/contrib/autograph/impl/api_test.py index 02cd8ed2d0ffee8ef2d31ea65902d2b493df9d64..a7737b7f448131b1c54951efa719b481e1f4d0c9 100644 --- a/tensorflow/contrib/py2tf/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -18,23 +18,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import api -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.pyct import parser +import numpy as np + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op -from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +tf = utils.fake_tf() + + class ApiTest(test.TestCase): def setUp(self): - config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) config.COMPILED_IMPORT_STATEMENTS = ( - 'from tensorflow.python.ops ' - 'import control_flow_ops as tf', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') + 'from __future__ import print_function', + 'from tensorflow.contrib.autograph import utils' + ' as autograph_utils', + 'tf = autograph_utils.fake_tf()', + ) def test_decorator_recurses(self): @@ -47,7 +53,7 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -63,11 +69,11 @@ class ApiTest(test.TestCase): class TestClass(object): def called_member(self, a): - return math_ops.negative(a) + return tf.negative(a) @api.convert(recursive=False) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -78,17 +84,17 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_converted(self): + def test_decorator_calls_unconverted_graph(self): class TestClass(object): - @api.graph_ready + @api.do_not_convert(api.RunMode.GRAPH) def called_member(self, a): - return math_ops.negative(a) + return tf.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= self.called_member(a) return x @@ -99,20 +105,23 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_decorator_calls_decorated(self): + def test_decorator_calls_unconverted_py_func(self): class TestClass(object): - @api.convert() + @api.do_not_convert( + api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1)) def called_member(self, a): - if a < 0: - a = -a - return a + return np.negative(a) @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= self.called_member(a) + while tf.reduce_sum(x) > s: + y = self.called_member(a) + # set_shape works around while_loop's limitations. + # TODO(mdan): Allow specifying shapes (or ShapeLike) instead. + y.set_shape(a.shape) + x //= y return x tc = TestClass() @@ -122,10 +131,11 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_convert_call_site_decorator(self): + def test_decorator_calls_decorated(self): class TestClass(object): + @api.convert() def called_member(self, a): if a < 0: a = -a @@ -133,8 +143,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= api.convert_inline(self.called_member, a) + while tf.reduce_sum(x) > s: + x //= self.called_member(a) return x tc = TestClass() @@ -144,17 +154,20 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) - def test_graph_ready_call_site_decorator(self): + def test_convert_call_site_decorator(self): class TestClass(object): def called_member(self, a): - return math_ops.negative(a) + if a < 0: + a = -a + return a @api.convert(recursive=True) def test_method(self, x, s, a): - while math_ops.reduce_sum(x) > s: - x //= api.graph_ready(self.called_member(a)) + while tf.reduce_sum(x) > s: + x //= api.converted_call(self.called_member, False, False, {}, self, + a) return x tc = TestClass() @@ -164,9 +177,96 @@ class ApiTest(test.TestCase): constant_op.constant(-2)) self.assertListEqual([0, 1], sess.run(x).tolist()) + def test_converted_call_builtin(self): + x = api.converted_call(range, False, False, {}, 3) + self.assertEqual((0, 1, 2), tuple(x)) + + def test_converted_call_function(self): + + def test_fn(x): + if x < 0: + return -x + return x + + with self.test_session() as sess: + x = api.converted_call( + test_fn, False, False, {}, constant_op.constant(-1)) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_method(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def test_method(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = TestClass(constant_op.constant(-1)) + x = api.converted_call(tc.test_method, False, False, {}, tc) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_method_by_class(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def test_method(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = TestClass(constant_op.constant(-1)) + x = api.converted_call(TestClass.test_method, False, False, {}, tc) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_callable_object(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def __call__(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = TestClass(constant_op.constant(-1)) + x = api.converted_call(tc, False, False, {}) + self.assertEqual(1, sess.run(x)) + + def test_converted_call_constructor(self): + + class TestClass(object): + + def __init__(self, x): + self.x = x + + def test_method(self): + if self.x < 0: + return -self.x + return self.x + + with self.test_session() as sess: + tc = api.converted_call( + TestClass, False, False, {}, constant_op.constant(-1)) + # tc is now a converted object. + x = tc.test_method() + self.assertEqual(1, sess.run(x)) + def test_to_graph_basic(self): + def test_fn(x, s): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x //= 2 return x @@ -177,15 +277,15 @@ class ApiTest(test.TestCase): self.assertListEqual([1, 2], sess.run(x).tolist()) def test_to_code_basic(self): + def test_fn(x, s): - while math_ops.reduce_sum(x) > s: + while tf.reduce_sum(x) > s: x /= 2 return x compiled_code = api.to_code(test_fn) - # Just check for some key words and that it is parseable Python code. - self.assertRegexpMatches(compiled_code, 'py2tf_utils\\.run_while') + # Just check that it is parseable Python code. self.assertIsNotNone(parser.parse_str(compiled_code)) diff --git a/tensorflow/contrib/py2tf/impl/config.py b/tensorflow/contrib/autograph/impl/config.py similarity index 73% rename from tensorflow/contrib/py2tf/impl/config.py rename to tensorflow/contrib/autograph/impl/config.py index 7c3ecefff0f8858d5505ff30e1270b2fd42c9ad8..2600088595a12761b1138c4649c06882bd8fd000 100644 --- a/tensorflow/contrib/py2tf/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf import utils +from tensorflow.contrib.autograph import utils PYTHON_LITERALS = { @@ -31,15 +31,19 @@ PYTHON_LITERALS = { DEFAULT_UNCOMPILED_MODULES = set(( ('tensorflow',), (utils.__name__,), + + # All of tensorflow's subpackages. Unlike the root tf module, they don't + # have well-known names. Not refering to the module directly to avoid + # circular imports. + ( + utils.__name__[:-len('.contrib.autograph.utils')],), )) NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) -# TODO(mdan): Also allow controlling the generated names (for testability). -# TODO(mdan): Verify that these names are not hidden by generated code. -# TODO(mdan): Make sure copybara renames the reference below. +# TODO(mdan): Also allow controlling the generated names. +# TODO(mdan); Consolidate all internal imports into a single __ag module. COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', 'import tensorflow as tf', - 'from tensorflow.contrib.py2tf import utils as ' - 'py2tf_utils') +) diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py similarity index 56% rename from tensorflow/contrib/py2tf/impl/conversion.py rename to tensorflow/contrib/autograph/impl/conversion.py index 3d5624b187ed47e9eed8afbb2e101e1098f81c15..c9fb972953ea315006b9f3005a5d3df5a2f043d3 100644 --- a/tensorflow/contrib/py2tf/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -18,28 +18,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import imp + import gast -import six - -from tensorflow.contrib.py2tf import utils -from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import break_statements -from tensorflow.contrib.py2tf.converters import builtin_functions -from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import continue_statements -from tensorflow.contrib.py2tf.converters import control_flow -from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.converters import for_loops -from tensorflow.contrib.py2tf.converters import logical_expressions -from tensorflow.contrib.py2tf.converters import side_effect_guards -from tensorflow.contrib.py2tf.impl import config -from tensorflow.contrib.py2tf.impl import naming -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info + +from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import asserts +from tensorflow.contrib.autograph.converters import break_statements +from tensorflow.contrib.autograph.converters import builtin_functions +from tensorflow.contrib.autograph.converters import call_trees +from tensorflow.contrib.autograph.converters import continue_statements +from tensorflow.contrib.autograph.converters import control_flow +from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info +from tensorflow.contrib.autograph.utils import type_hints from tensorflow.python.util import tf_inspect @@ -47,27 +54,37 @@ from tensorflow.python.util import tf_inspect class ConversionMap(object): - """ConversionMaps keep track of converting function hierarchies. + """ConversionMap keeps track of converting function hierarchies. + + This object is mutable, and is updated as functions are converted. Attributes: - recursive: Whether to recusrively convert any functions that the decorator + recursive: Whether to recursively convert any functions that the decorator function may call. nocompile_decorators: tuple of decorator functions that toggle compilation off. dependency_cache: dict[object]: ast; maps original entities to their converted AST + additional_imports: set(object); additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code name_map: dict[string]: string; maps original entities to the name of their converted counterparts + api_module: A reference to the api module. The reference needs to be passed + to avoid circular dependencies. """ # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - def __init__(self, recursive, nocompile_decorators, partial_types): + def __init__(self, recursive, nocompile_decorators, partial_types, + api_module): self.recursive = recursive self.nocompile_decorators = nocompile_decorators self.partial_types = partial_types if partial_types else () self.dependency_cache = {} + self.additional_imports = set() self.name_map = {} + self.api_module = api_module def new_namer(self, namespace): return naming.Namer(namespace, self.recursive, self.name_map, @@ -88,6 +105,24 @@ class ConversionMap(object): self.dependency_cache[original_entity] = converted_ast +def is_whitelisted_for_graph(o): + """Check whether an entity is whitelisted for use in graph mode. + + Examples of whitelisted entities include all members of the tensorflow + package. + + Args: + o: A Python entity. + Returns: + Boolean + """ + m = tf_inspect.getmodule(o) + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + return True + return False + + def entity_to_graph(o, conversion_map, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. @@ -106,20 +141,22 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): parameters. Returns: - A tuple (ast, new_name): + A tuple (ast, new_name, namespace): * ast: An AST representing an entity with interface equivalent to `o`, but which when executed it creates TF a graph. * new_name: The symbol name under which the new entity can be found. + * namespace: A dict mapping all symbols visible to the converted entity, + keyed by their symbol name. Raises: ValueError: if the entity type is not supported. """ if tf_inspect.isclass(o): - node, new_name = class_to_graph(o, conversion_map) + node, name, ns = class_to_graph(o, conversion_map) elif tf_inspect.isfunction(o): - node, new_name = function_to_graph(o, conversion_map, arg_values, arg_types) + node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) elif tf_inspect.ismethod(o): - node, new_name = function_to_graph(o, conversion_map, arg_values, arg_types) + node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' @@ -127,47 +164,76 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): conversion_map.add_to_cache(o, node) if conversion_map.recursive: - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: - if (hasattr(obj, 'im_class') and - getattr(obj, 'im_class') not in conversion_map.partial_types): - # Class members are converted with their objects, unless they're - # only converted partially. - continue - entity_to_graph(obj, conversion_map, {}, {}) - - return node, new_name + while True: + candidate = None + for obj in conversion_map.name_map.keys(): + if obj not in conversion_map.dependency_cache: + candidate = obj + break + if candidate is None: + break + if (hasattr(candidate, 'im_class') and + getattr(candidate, 'im_class') not in conversion_map.partial_types): + # Class members are converted with their objects, unless they're + # only converted partially. + continue + entity_to_graph(candidate, conversion_map, {}, {}) + + return node, name, ns def class_to_graph(c, conversion_map): """Specialization of `entity_to_graph` for classes.""" converted_members = {} - members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod) + method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) + members = tf_inspect.getmembers(c, predicate=method_filter) if not members: - raise ValueError('Cannot convert %s: it has no member methods.') + raise ValueError('Cannot convert %s: it has no member methods.' % c) - class_globals = None + class_namespace = {} for _, m in members: - node, _ = function_to_graph( + node, _, namespace = function_to_graph( m, conversion_map=conversion_map, arg_values={}, arg_types={'self': (c.__name__, c)}, owner_type=c) - # TODO(mdan): Do not assume all members have the same view of globals. - if class_globals is None: - class_globals = six.get_function_globals(m) + if class_namespace is None: + class_namespace = namespace + else: + class_namespace.update(namespace) converted_members[m] = node - namer = conversion_map.new_namer(class_globals) + namer = conversion_map.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) node = gast.ClassDef( class_name, bases=[], keywords=[], - body=converted_members.values(), + body=list(converted_members.values()), decorator_list=[]) - return node, class_name + return node, class_name, class_namespace + + +def _add_reserved_symbol(namespace, name, entity): + if name not in namespace: + namespace[name] = entity + elif namespace[name] != entity: + raise ValueError('The name "%s" is reserved and may not be used.' % name) + + +def _add_self_references(namespace, api_module): + # Craft a module that exposes parts of the external API as well as certain + # internal modules. + ag_internal = imp.new_module('autograph') + ag_internal.converted_call = api_module.converted_call + ag_internal.utils = utils + # TODO(mdan): Add safeguards against name clashes. + # We don't want to create a submodule because we want the operators to be + # accessible as ag__. + ag_internal.__dict__.update(operators.__dict__) + + _add_reserved_symbol(namespace, 'ag__', ag_internal) def function_to_graph(f, conversion_map, arg_values, arg_types, @@ -175,24 +241,11 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] - namespace = six.get_function_globals(f) - - # This is needed for non-global functions. - closure = six.get_function_closure(f) - if closure: - for e in closure: - if callable(e.cell_contents): - fn = e.cell_contents - namespace[fn.__name__] = fn - - # Manually add the utils namespace which may be used from generated code. - if 'py2tf_util' not in namespace: - namespace['py2tf_utils'] = utils - elif namespace['py2tf_utils'] != utils: - raise ValueError( - 'The module name py2tf_utils is reserved and may not be used.') + namespace = inspect_utils.getnamespace(f) + _add_self_references(namespace, conversion_map.api_module) namer = conversion_map.new_namer(namespace) + ctx = context.EntityContext( namer=namer, source_code=source, @@ -200,8 +253,10 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, namespace=namespace, arg_values=arg_values, arg_types=arg_types, - recursive=conversion_map.recursive) - node = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + owner_type=owner_type, + recursive=conversion_map.recursive, + type_annotation_func=type_hints.set_element_type) + node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -212,7 +267,10 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, node.name = new_name conversion_map.update_name_map(namer) - return node, new_name + # TODO(mdan): Use this at compilation. + conversion_map.additional_imports.update(deps) + + return node, new_name, namespace def _static_analysis_pass(node, ctx): @@ -250,12 +308,19 @@ def node_to_graph(node, ctx, nocompile_decorators): # to re-run the analysis. node = _static_analysis_pass(node, ctx) + + # TODO(mdan): Clean this up. + # Some intermediate analyses are not required, and some comments got orphaned. + # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? ctx.source_code = None - node = decorators.transform(node, nocompile_decorators) + node = ifexp.transform(node, ctx) + node, deps = decorators.transform(node, nocompile_decorators) node = break_statements.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + node = asserts.transform(node, ctx) # Note: sequencing continue canonicalization before for loop one avoids @@ -265,8 +330,10 @@ def node_to_graph(node, ctx, nocompile_decorators): ctx.namespace['len'] = len node = _static_analysis_pass(node, ctx) - node = for_loops.transform(node, ctx) - # for_loops may insert new global references. + node = single_return.transform(node, ctx) + + node = _static_analysis_pass(node, ctx) + node = lists.transform(node, ctx) node = builtin_functions.transform(node, ctx) node = _static_analysis_pass(node, ctx) @@ -276,7 +343,8 @@ def node_to_graph(node, ctx, nocompile_decorators): # control_flow may create new symbols and change scopes. node = _static_analysis_pass(node, ctx) - node = logical_expressions.transform(node) + node = logical_expressions.transform(node, ctx) node = side_effect_guards.transform(node, ctx) + node = name_scopes.transform(node, ctx) - return node + return node, deps diff --git a/tensorflow/contrib/py2tf/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py similarity index 60% rename from tensorflow/contrib/py2tf/impl/conversion_test.py rename to tensorflow/contrib/autograph/impl/conversion_test.py index 3888958f19b9fa13b759924c5188722e500e30a1..f0b597c12fd3dfb503a68e46d0e041c0e67d65ad 100644 --- a/tensorflow/contrib/py2tf/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -20,26 +20,42 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.impl import conversion +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api +from tensorflow.contrib.autograph.impl import conversion +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ConversionTest(test.TestCase): + def _simple_conversion_map(self): + return conversion.ConversionMap(True, (), (), api) + + def test_is_whitelisted_for_graph(self): + + def test_fn(): + return constant_op.constant(1) + + self.assertFalse(conversion.is_whitelisted_for_graph(test_fn)) + self.assertTrue(conversion.is_whitelisted_for_graph(utils)) + self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) + def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = self._simple_conversion_map() conversion.entity_to_graph('dummy', conversion_map, None, None) def test_entity_to_graph_callable(self): - + b = 2 def f(a): - return a + return a + b - conversion_map = conversion.ConversionMap(True, (), ()) - ast, new_name = conversion.entity_to_graph(f, conversion_map, None, None) + conversion_map = self._simple_conversion_map() + ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) - self.assertEqual('tf__f', new_name) + self.assertEqual('tf__f', name) + self.assertTrue(ns['b'] is b) def test_entity_to_graph_call_tree(self): @@ -49,14 +65,17 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = self._simple_conversion_map() conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(f in conversion_map.dependency_cache) self.assertTrue(g in conversion_map.dependency_cache) self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) + # need the extra .body[0] in order to step past the with tf.name_scope('f') + # that is added automatically self.assertEqual( - 'tf__g', conversion_map.dependency_cache[f].body[0].value.func.id) + 'tf__g', + conversion_map.dependency_cache[f].body[0].body[0].value.func.id) self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) diff --git a/tensorflow/contrib/py2tf/impl/naming.py b/tensorflow/contrib/autograph/impl/naming.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming.py rename to tensorflow/contrib/autograph/impl/naming.py index 51326091de13715c32d0a79279f1d3274e48ad10..1facaa0ca0ebcc6d4281e7c92a462ceeb00b453a 100644 --- a/tensorflow/contrib/py2tf/impl/naming.py +++ b/tensorflow/contrib/autograph/impl/naming.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import qual_names class Namer(object): diff --git a/tensorflow/contrib/py2tf/impl/naming_test.py b/tensorflow/contrib/autograph/impl/naming_test.py similarity index 98% rename from tensorflow/contrib/py2tf/impl/naming_test.py rename to tensorflow/contrib/autograph/impl/naming_test.py index beb4e54937bbb91b19157c9b9e3c528353206c62..73fc0894655cb49e4f61bf8ca51995b06feb3072 100644 --- a/tensorflow/contrib/py2tf/impl/naming_test.py +++ b/tensorflow/contrib/autograph/impl/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.impl import naming +from tensorflow.contrib.autograph.impl import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..efb8d441dd839bd34dcbc18f701c7993a2f03906 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -0,0 +1,53 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "operators", + srcs = [ + "__init__.py", + "control_flow.py", + "data_structures.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "data_structures_test", + srcs = ["data_structures_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "control_flow_test", + srcs = ["control_flow_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04b4734551d3227a1c611d668f006a157c2c2dd3 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module implements operators that we overload. + +Note that "operator" is used loosely here, and includes control structures like +conditionals and loops, implemented in functional form, using for example +closures for the body. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# TODO(mdan): Add a container for implementation-specific toggles (throughout). + +from tensorflow.contrib.autograph.operators.control_flow import for_loop +from tensorflow.contrib.autograph.operators.control_flow import while_loop diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d8b0d593e5372942ca6423d10022f0f56d78ce --- /dev/null +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow statements: loops, conditionals, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_math_ops + +# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature. +# TODO(mdan): Rename arguments to match the AST names. + + +def for_loop(iterated, extra_cond, loop_body, init_state): + """Functional form of a for statement. + + The loop operates on a so-called state, which includes all symbols that are + variant across loop iterations, excluding the iterate. In what follows we + refer to state as either a tuple of entities that represent an actual state, + or a list of arguments of the corresponding types. + + Args: + iterated: The entity being iterated over. + extra_cond: Callable with the state as arguments, and boolean return type. + An additionnal loop condition. + loop_body: Callable with the iterate and the state as arguments, and + state as return type. The actual loop body. + init_state: Tuple containing the initial state. + + Returns: + Tuple containing the final state. + """ + if tensor_util.is_tensor(iterated): + return _known_len_for_loop(iterated, extra_cond, loop_body, init_state) + elif isinstance(iterated, dataset_ops.Dataset): + return _dataset_for_loop(iterated, extra_cond, loop_body, init_state) + else: + return _py_for_loop(iterated, extra_cond, loop_body, init_state) + + +def _py_for_loop(iterated, extra_cond, loop_body, init_state): + """Overload of for_loop that executes a Python for loop.""" + state = init_state + for iterate in iterated: + if not extra_cond(*state): + break + state = loop_body(iterate, *state) + + # TODO(mdan): Remove this special case. + if len(state) == 1: + return state[0] + return state + + +def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): + """Overload of for_loop that iterates over objects that define a length.""" + n = builtins.dynamic_len(iterated) + + def while_body(iterate_index, *state): + iterate = iterated[iterate_index] + new_state = loop_body(iterate, *state) + return (iterate_index + 1,) + new_state + + def while_cond(iterate_index, *state): + return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state)) + + results = while_loop( + while_cond, + while_body, + init_state=(0,) + init_state, + extra_deps=(iterated,), + opts=dict(maximum_iterations=n)) + # Dropping the iteration index because it's not syntactically visible. + results = results[1:] + + # TODO(mdan): Remove this special case. + if len(results) == 1: + return results[0] + return results + + +def _dataset_for_loop(ds, extra_cond, loop_body, init_state): + """Overload of for_loop that iterates over TF Datasets.""" + # Because Datsets only expose get_next, in the style of Python iterators, + # we are forced to unpack the loop as: + # + # epoch_number, iterate = ds.get_next() + # while epoch_number < 2: + # + # epoch_number, iterate = ds.get_next() + epoch_numbers = dataset_ops.Dataset.range(2) + def tag_with(ds, tag): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(tag).repeat(), ds)) + ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i)) + + iterator = ds_with_epoch.make_initializable_iterator() + with ops.control_dependencies((iterator.initializer,)): + epoch_number, iterate = iterator.get_next() + + def while_body(epoch_number, iterate, *state): + new_state = loop_body(iterate, *state) + epoch_number, iterate = iterator.get_next() + return (epoch_number, iterate) + new_state + + def while_cond(epoch_number, iterate, *state): + del iterate + return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state)) + + results = while_loop( + while_cond, + while_body, + init_state=(epoch_number, iterate) + init_state, + extra_deps=()) + # Dropping the epoch number and iterate because they are not not syntactically + # visible. + results = results[2:] + + # TODO(mdan): Remove this special case. + if len(results) == 1: + return results[0] + return results + + +def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): + """Functional form of a while statement. + + The loop operates on a so-called state, which includes all symbols that are + variant across loop iterations. In what follows we refer to state as either + a tuple of entities that represent an actual state, or a list of arguments + of the corresponding types. + + Args: + loop_cond: Callable with the state as arguments, and boolean return type. + The loop condition. + loop_body: Callable with the state as arguments, and state as return type. + The actual loop body. + init_state: Tuple containing the initial state. + extra_deps: Tuple containing additional entities on which the loop may + depend, such as loop invariants referenced by loop_cond. Used + exclusively for dispatch control. + opts: Optional dict of extra loop parameters. + + Returns: + Tuple containing the final state. + """ + # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch. + # That could be somethins as simple as a collection of dispatch rules, with + # some prioritization. + if any(tensor_util.is_tensor(v) for v in init_state + extra_deps): + return _tf_while_loop(loop_cond, loop_body, init_state, opts) + else: + return _py_while_loop(loop_cond, loop_body, init_state, opts) + + +def _tf_while_loop(loop_cond, loop_body, init_state, opts): + """Overload of while_loop that stages a TF while_loop.""" + if opts is None: + opts = {} + return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts) + + +def _py_while_loop(loop_cond, loop_body, init_state, opts): + """Overload of while_loop that executes a Python while loop.""" + del opts + state = init_state + while loop_cond(*state): + state = loop_body(*state) + return state + + +def if_stmt(cond, body, orelse): + """Functional form of an if statement. + + Args: + cond: Boolean. + body: Callable with no arguments, and outputs of the positive (if) branch + as return type. + orelse: Callable with no arguments, and outputs of the negative (else) + branch as return type. + + Returns: + Tuple containing the statement outputs. + """ + if tensor_util.is_tensor(cond): + return _tf_if_stmt(cond, body, orelse) + else: + return _py_if_stmt(cond, body, orelse) + + +def _tf_if_stmt(cond, body, orelse): + """Overload of if_stmt that stages a TF cond.""" + return control_flow_ops.cond(cond, body, orelse) + + +def _py_if_stmt(cond, body, orelse): + """Overload of if_stmt that executes a Python if statement.""" + return body() if cond else orelse() diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0cd0bfa82bb052d55dfe30f8700fc33a794a59f --- /dev/null +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for control_flow module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import control_flow +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ForLoopTest(test.TestCase): + + def test_tensor(self): + s = control_flow.for_loop( + constant_op.constant([1, 2, 3, 4]), + extra_cond=lambda s: True, + loop_body=lambda i, s: (s + i,), + init_state=(0,)) + with self.test_session() as sess: + self.assertEqual((10,), sess.run(s)) + + def test_python(self): + s = control_flow.for_loop( + range(5), + extra_cond=lambda s: True, + loop_body=lambda i, s: (s + i,), + init_state=(0,)) + self.assertEqual(10, s) + + def test_dataset(self): + to_int32 = lambda i: math_ops.cast(i, dtypes.int32) + s = control_flow.for_loop( + dataset_ops.Dataset.range(5).map(to_int32), + extra_cond=lambda s: True, + loop_body=lambda i, s: (s + i,), + init_state=(0,)) + with self.test_session() as sess: + self.assertEqual((10,), sess.run(s)) + + +class WhileLoopTest(test.TestCase): + + def test_tensor(self): + n = constant_op.constant(5) + results = control_flow.while_loop( + loop_cond=lambda i, s: i < n, + loop_body=lambda i, s: (i + 1, s + i,), + init_state=(0, 0), + extra_deps=(n,)) + with self.test_session() as sess: + self.assertEqual((5, 10), sess.run(results)) + + def test_python(self): + n = 5 + results = control_flow.while_loop( + loop_cond=lambda i, s: i < n, + loop_body=lambda i, s: (i + 1, s + i), + init_state=(0, 0), + extra_deps=(n,)) + self.assertEqual((5, 10), results) + + +class IfStmtTest(test.TestCase): + + def test_tensor(self): + def test_if_stmt(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: 1, + orelse=lambda: -1) + with self.test_session() as sess: + self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True)))) + self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False)))) + + def test_python(self): + self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1)) + self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..c862306baa9e8114a71a26323ddcbd35c8592c55 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -0,0 +1,56 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to data structures: list append, subscripts, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import tensor_array_ops + +# TODO(mdan): Add support for TensorList once functional. +# TODO(mdan): Add primitives for empty list, list with elements. + + +def append(target, element): + """The list append function. + + Note: it is unspecified where target will be mutated or not. If target is + a TensorFlow entity, it will not be typically mutated. If target is a plain + list, it will be. In general, if the target is mutated then the return value + should point to the original entity. + + Args: + target: An entity that supports append semantics. + element: The element to append. + + Returns: + Same as target, after the append was performed. + """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(target, element) + else: + return _py_append(target, element) + + +def _tf_tensorarray_append(target, element): + """Overload of append that stages a TensorArray write at the last position.""" + return target.write(target.size(), element) + + +def _py_append(target, element): + """Overload of append that executes a Python list append.""" + target.append(element) + return target diff --git a/tensorflow/contrib/py2tf/utils/printing_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py similarity index 52% rename from tensorflow/contrib/py2tf/utils/printing_test.py rename to tensorflow/contrib/autograph/operators/data_structures_test.py index 2070deb304d8df2433fb9a95ae36d48415578482..577d28c34da39f1216669513c29a00ac07bec126 100644 --- a/tensorflow/contrib/py2tf/utils/printing_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -12,41 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for printing module.""" +"""Tests for data_structures module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys +from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test -import six -from tensorflow.contrib.py2tf.utils import printing -from tensorflow.python.platform import test +class AppendTest(test.TestCase): + def test_tf_tensorarray(self): + l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) + l1 = data_structures.append(l, 1) + l2 = data_structures.append(l1, 2) + with self.test_session() as sess: + self.assertAllEqual(sess.run(l1.stack()), [1]) + self.assertAllEqual(sess.run(l2.stack()), [1, 2]) -class ContextManagersTest(test.TestCase): - - def test_call_print_tf(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(printing.call_print('test message', 1)) - self.assertEqual(out_capturer.getvalue(), 'test message 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_call_print_py_func(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(printing.call_print('test message', [1, 2])) - self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') - finally: - sys.stdout = sys.__stdout__ + def test_python(self): + l = [] + self.assertAllEqual(data_structures.append(l, 1), [1]) + self.assertAllEqual(data_structures.append(l, 2), [1, 2]) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD similarity index 81% rename from tensorflow/contrib/py2tf/pyct/BUILD rename to tensorflow/contrib/autograph/pyct/BUILD index e3c0da4b10f9ffbee1b2a906b64d4762f41d97b4..796ab445c74128e1123e24b67c288e0e3c5ca24c 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -24,6 +24,7 @@ py_library( "ast_util.py", "compiler.py", "context.py", + "inspect_utils.py", "parser.py", "pretty_printer.py", "qual_names.py", @@ -65,6 +66,18 @@ py_test( name = "compiler_test", srcs = ["compiler_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + +py_test( + name = "inspect_utils_test", + srcs = ["inspect_utils_test.py"], + srcs_version = "PY2AND3", deps = [ ":pyct", "//tensorflow/python:client_testlib", @@ -112,3 +125,14 @@ py_test( "@gast_archive//:gast", ], ) + +py_test( + name = "transformer_test", + srcs = ["transformer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) diff --git a/tensorflow/contrib/py2tf/pyct/__init__.py b/tensorflow/contrib/autograph/pyct/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/__init__.py rename to tensorflow/contrib/autograph/pyct/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py similarity index 92% rename from tensorflow/contrib/py2tf/pyct/anno.py rename to tensorflow/contrib/autograph/pyct/anno.py index 7a0528b6d0b65b6604930b7a13d8493af9d61f02..cc4a7edf02ed7556c9a552d8730e4c7875038c83 100644 --- a/tensorflow/contrib/py2tf/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -70,3 +70,8 @@ def delanno(node, key, field_name='___pyct_anno'): if not annotations: delattr(node, field_name) node._fields = tuple(f for f in node._fields if f != field_name) + + +def copyanno(from_node, to_node, key, field_name='___pyct_anno'): + if hasanno(from_node, key, field_name): + setanno(to_node, key, getanno(from_node, key, field_name), field_name) diff --git a/tensorflow/contrib/py2tf/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py similarity index 77% rename from tensorflow/contrib/py2tf/pyct/anno_test.py rename to tensorflow/contrib/autograph/pyct/anno_test.py index ff40bfe1f50ae731648afdf509c26c3a70d3f6cb..1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d 100644 --- a/tensorflow/contrib/py2tf/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -20,10 +20,13 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno from tensorflow.python.platform import test +# TODO(mdan): Consider strong types instead of primitives. + + class AnnoTest(test.TestCase): def test_basic(self): @@ -42,6 +45,17 @@ class AnnoTest(test.TestCase): with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + def test_copyanno(self): + node_1 = ast.Name() + anno.setanno(node_1, 'foo', 3) + + node_2 = ast.Name() + anno.copyanno(node_1, node_2, 'foo') + anno.copyanno(node_1, node_2, 'bar') + + self.assertTrue(anno.hasanno(node_2, 'foo')) + self.assertFalse(anno.hasanno(node_2, 'bar')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py similarity index 84% rename from tensorflow/contrib/py2tf/pyct/ast_util.py rename to tensorflow/contrib/autograph/pyct/ast_util.py index f916775b9cf3cec960ec2896c334f1d737862205..4a70bab4402a940dec6a8b183daf7406a7e34131 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -22,13 +22,13 @@ import ast import gast -from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.autograph.pyct import anno class CleanCopier(gast.NodeVisitor): """Copy AST nodes. - The copied nodes will ignore almost all fields that prefixed by '__'. + The copied nodes will ignore almost all fields that are prefixed by '__'. Exceptions make some annotations. """ @@ -84,7 +84,10 @@ class SymbolRenamer(gast.NodeTransformer): return self._process(node) def visit_Attribute(self, node): - return self._process(node) + if anno.hasanno(node, anno.Basic.QN): + return self._process(node) + # Attributes of dynamic objects will not have a QN. + return self.generic_visit(node) def rename_symbols(node, name_map): @@ -94,3 +97,12 @@ def rename_symbols(node, name_map): elif isinstance(node, tuple): return tuple(renamer.visit(n) for n in node) return renamer.visit(node) + + +def keywords_to_dict(keywords): + keys = [] + values = [] + for kw in keywords: + keys.append(gast.Str(kw.arg)) + values.append(kw.value) + return gast.Dict(keys=keys, values=values) diff --git a/tensorflow/contrib/py2tf/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py similarity index 72% rename from tensorflow/contrib/py2tf/pyct/ast_util_test.py rename to tensorflow/contrib/autograph/pyct/ast_util_test.py index e0b00c178168f96e656c57cc75a76e6da8af1d8a..8faf92c705d997db298dbb1115981fd9da26372d 100644 --- a/tensorflow/contrib/py2tf/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.python.platform import test @@ -33,15 +35,15 @@ class AstUtilTest(test.TestCase): ast.Name('b', ast.Load()), ast.Attribute(ast.Name('b', None), 'c', ast.Store()), ast.Attribute( - ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', - None) + ast.Attribute(ast.Name('b', None), 'c', ast.Load()), 'd', None) ], None) node = qual_names.resolve(node) node = ast_util.rename_symbols( - node, - { - qual_names.QN('a'): qual_names.QN('renamed_a'), - qual_names.QN('b.c'): qual_names.QN('renamed_b_c'), + node, { + qual_names.QN('a'): + qual_names.QN('renamed_a'), + qual_names.QN(qual_names.QN('b'), attr='c'): + qual_names.QN('renamed_b_c'), }) self.assertEqual(node.elts[0].id, 'renamed_a') @@ -74,6 +76,17 @@ class AstUtilTest(test.TestCase): self.assertFalse(ret is new_node.body[0]) self.assertFalse(hasattr(new_node.body[0], '__foo')) + def test_keywords_to_dict(self): + keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords + d = ast_util.keywords_to_dict(keywords) + # Make sure we generate a usable dict node by attaching it to a variable and + # compiling everything. + output = parser.parse_str('b = 3') + output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) + result, _ = compiler.ast_to_object(output) + self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) + print(d) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/autograph/pyct/compiler.py similarity index 90% rename from tensorflow/contrib/py2tf/pyct/compiler.py rename to tensorflow/contrib/autograph/pyct/compiler.py index 51cf6930e8bcb3728ee55bf5d4781f01a5ef73bd..24c4517afa89147101f80af3ef60237132c1144c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/autograph/pyct/compiler.py @@ -31,7 +31,7 @@ import astor import gast -def ast_to_source(node, indentation): +def ast_to_source(node, indentation=' '): """Return the source code of given AST.""" if isinstance(node, gast.AST): node = gast.gast_to_ast(node) @@ -39,7 +39,10 @@ def ast_to_source(node, indentation): astor.string_repr.pretty_string) generator.visit(node) generator.result.append('\n') - return astor.source_repr.pretty_source(generator.result).lstrip() + # In some versions of Python, literals may appear as actual values. This + # ensures everything is string. + code = map(str, generator.result) + return astor.source_repr.pretty_source(code).lstrip() def ast_to_object( diff --git a/tensorflow/contrib/py2tf/pyct/compiler_test.py b/tensorflow/contrib/autograph/pyct/compiler_test.py similarity index 83% rename from tensorflow/contrib/py2tf/pyct/compiler_test.py rename to tensorflow/contrib/autograph/pyct/compiler_test.py index c1f84238efa7dd6fc0748748a2cb4f074572b4c6..98cdc1506b6aced603df99662f1468687a55f92c 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler_test.py +++ b/tensorflow/contrib/autograph/pyct/compiler_test.py @@ -22,12 +22,29 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect class CompilerTest(test.TestCase): + def test_parser_compile_idempotent(self): + + def test_fn(x): + a = True + b = '' + if a: + b = x + 1 + return b + + self.assertEqual( + textwrap.dedent(tf_inspect.getsource(test_fn)), + tf_inspect.getsource( + compiler.ast_to_object( + parser.parse_entity(test_fn)[0].body[0])[0].test_fn)) + def test_ast_to_source(self): node = gast.If( test=gast.Num(1), diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py similarity index 82% rename from tensorflow/contrib/py2tf/pyct/context.py rename to tensorflow/contrib/autograph/pyct/context.py index fef74ebefa290369c7310af6d7e4faeef44d9aee..b34015cfd2888f0dbeb6492b9e7335d561bf4763 100644 --- a/tensorflow/contrib/py2tf/pyct/context.py +++ b/tensorflow/contrib/autograph/pyct/context.py @@ -22,6 +22,8 @@ from __future__ import print_function class EntityContext(object): """Contains information about an entity, like source code. + In general, objects of this class should be considered immutable. + Attributes: namer: Namer that matches the contract of all converters. source_code: The entity's source code. @@ -30,14 +32,18 @@ class EntityContext(object): (excluding parameters). arg_values: Dict[str->*], containing parameter values, if known. arg_types: Dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. """ + # TODO(mdan): Remove the default and update tests. def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, recursive): + arg_types, owner_type, recursive, type_annotation_func=None): self.namer = namer self.source_code = source_code self.source_file = source_file self.namespace = namespace self.arg_values = {} if arg_values is None else arg_values self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type self.recursive = recursive + self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..63361cc4f2557d22800072d90a51b7e4ddab34ab --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py @@ -0,0 +1,150 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Live entity inspection utilities. + +This module contains whatever inspect doesn't offer out of the box. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import types + +import six + +from tensorflow.python.util import tf_inspect + + +def isbuiltin(f): + # Note these return false for isinstance(f, types.BuiltinFunctionType) so we + # need to specifically check for them. + if f in (range, int, float): + return True + if isinstance(f, types.BuiltinFunctionType): + return True + if tf_inspect.isbuiltin(f): + return True + return False + + +def getnamespace(f): + """Returns the complete namespace of a function. + + Namespace is defined here as the mapping of all non-local variables to values. + This includes the globals and the closure variables. Note that this captures + the entire globals collection of the function, and may contain extra symbols + that it does not actually use. + + Args: + f: User defined function. + Returns: + A dict mapping symbol names to values. + """ + namespace = dict(six.get_function_globals(f)) + closure = six.get_function_closure(f) + freevars = six.get_function_code(f).co_freevars + if freevars and closure: + for name, cell in zip(freevars, closure): + namespace[name] = cell.cell_contents + return namespace + + +def getdefiningclass(m, owner_class): + """Resolves the class (e.g. one of the superclasses) that defined a method.""" + m = six.get_unbound_function(m) + last_defining = owner_class + for superclass in tf_inspect.getmro(owner_class): + if hasattr(superclass, m.__name__): + superclass_m = getattr(superclass, m.__name__) + if six.get_unbound_function(superclass_m) == m: + last_defining = superclass + return last_defining + + +def getmethodclass(m): + """Resolves a function's owner, e.g. a method's class. + + Note that this returns the object that the function was retrieved from, not + necessarily the class where it was defined. + + This function relies on Python stack frame support in the interpreter, and + has the same limitations that inspect.currentframe. + + Limitations. This function will only work correctly if the owned class is + visible in the caller's global or local variables. + + Args: + m: A user defined function + + Returns: + The class that this function was retrieved from, or None if the function + is not an object or class method, or the class that owns the object or + method is not visible to m. + + Raises: + ValueError: if the class could not be resolved for any unexpected reason. + """ + + # Callable objects: return their own class. + if (not hasattr(m, '__name__') and hasattr(m, '__class__') and + hasattr(m, '__call__')): + if isinstance(m.__class__, six.class_types): + return m.__class__ + + # Instance method and class methods: should be bound to a non-null "self". + # If self is a class, then it's a class method. + if hasattr(m, '__self__'): + if m.__self__: + if tf_inspect.isclass(m.__self__): + return m.__self__ + return type(m.__self__) + + # Class, static and unbound methods: search all defined classes in any + # namespace. This is inefficient but more robust method. + owners = [] + caller_frame = tf_inspect.currentframe().f_back + try: + # TODO(mdan): This doesn't consider cell variables. + # TODO(mdan): This won't work if the owner is hidden inside a container. + # Cell variables may be pulled using co_freevars and the closure. + for v in itertools.chain(caller_frame.f_locals.values(), + caller_frame.f_globals.values()): + if hasattr(v, m.__name__): + candidate = getattr(v, m.__name__) + # Py2 methods may be bound or unbound, extract im_func to get the + # underlying function. + if hasattr(candidate, 'im_func'): + candidate = candidate.im_func + if hasattr(m, 'im_func'): + m = m.im_func + if candidate is m: + owners.append(v) + finally: + del caller_frame + + if owners: + if len(owners) == 1: + return owners[0] + + # If multiple owners are found, and are not subclasses, raise an error. + owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners) + for o in owner_types: + if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)): + return o + raise ValueError('Found too many owners of %s: %s' % (m, owners)) + + return None diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cf841dae814f64583bc43a2e110f1dcf5c0d7c1f --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -0,0 +1,270 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for unspect_utils module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +import six + +from tensorflow.contrib.autograph.pyct import inspect_utils +from tensorflow.python.platform import test + + +def decorator(f): + return f + + +def function_decorator(): + def dec(f): + return f + return dec + + +def wrapping_decorator(): + def dec(f): + def replacement(*_): + return None + + @wraps(f) + def wrapper(*args, **kwargs): + return replacement(*args, **kwargs) + return wrapper + return dec + + +class TestClass(object): + + def member_function(self): + pass + + @decorator + def decorated_member(self): + pass + + @function_decorator() + def fn_decorated_member(self): + pass + + @wrapping_decorator() + def wrap_decorated_member(self): + pass + + @staticmethod + def static_method(): + pass + + @classmethod + def class_method(cls): + pass + + +def free_function(): + pass + + +def factory(): + return free_function + + +def free_factory(): + def local_function(): + pass + return local_function + + +class InspectUtilsTest(test.TestCase): + + def test_getnamespace_globals(self): + ns = inspect_utils.getnamespace(factory) + self.assertEqual(ns['free_function'], free_function) + + def test_getnamespace_hermetic(self): + + # Intentionally hiding the global function to make sure we don't overwrite + # it in the global namespace. + free_function = object() # pylint:disable=redefined-outer-name + + def test_fn(): + return free_function + + ns = inspect_utils.getnamespace(test_fn) + globs = six.get_function_globals(test_fn) + self.assertTrue(ns['free_function'] is free_function) + self.assertFalse(globs['free_function'] is free_function) + + def test_getnamespace_locals(self): + + def called_fn(): + return 0 + + closed_over_list = [] + closed_over_primitive = 1 + + def local_fn(): + closed_over_list.append(1) + local_var = 1 + return called_fn() + local_var + closed_over_primitive + + ns = inspect_utils.getnamespace(local_fn) + self.assertEqual(ns['called_fn'], called_fn) + self.assertEqual(ns['closed_over_list'], closed_over_list) + self.assertEqual(ns['closed_over_primitive'], closed_over_primitive) + self.assertTrue('local_var' not in ns) + + def test_getmethodclass(self): + + self.assertEqual( + inspect_utils.getmethodclass(free_function), None) + self.assertEqual( + inspect_utils.getmethodclass(free_factory()), None) + + self.assertEqual( + inspect_utils.getmethodclass(TestClass.member_function), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.fn_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.wrap_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.static_method), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.class_method), + TestClass) + + test_obj = TestClass() + self.assertEqual( + inspect_utils.getmethodclass(test_obj.member_function), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.fn_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.static_method), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.class_method), + TestClass) + + def test_getmethodclass_locals(self): + + def local_function(): + pass + + class LocalClass(object): + + def member_function(self): + pass + + @decorator + def decorated_member(self): + pass + + @function_decorator() + def fn_decorated_member(self): + pass + + @wrapping_decorator() + def wrap_decorated_member(self): + pass + + self.assertEqual( + inspect_utils.getmethodclass(local_function), None) + + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.member_function), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.fn_decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.wrap_decorated_member), + LocalClass) + + test_obj = LocalClass() + self.assertEqual( + inspect_utils.getmethodclass(test_obj.member_function), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.fn_decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), + LocalClass) + + def test_getmethodclass_callables(self): + class TestCallable(object): + + def __call__(self): + pass + + c = TestCallable() + self.assertEqual(inspect_utils.getmethodclass(c), TestCallable) + + def test_getdefiningclass(self): + class Superclass(object): + + def foo(self): + pass + + def bar(self): + pass + + class Subclass(Superclass): + + def foo(self): + pass + + def baz(self): + pass + + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass) + + def test_isbuiltin(self): + self.assertTrue(inspect_utils.isbuiltin(range)) + self.assertTrue(inspect_utils.isbuiltin(float)) + self.assertTrue(inspect_utils.isbuiltin(int)) + self.assertTrue(inspect_utils.isbuiltin(len)) + self.assertFalse(inspect_utils.isbuiltin(function_decorator)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/contrib/autograph/pyct/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..c961efa892df6a21804dae8f52ef64bf99cd409e --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/parser.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converting code to AST. + +Adapted from Tangent. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +import gast + +from tensorflow.python.util import tf_inspect + + +def parse_entity(entity): + """Returns the AST of given entity.""" + source = tf_inspect.getsource(entity) + source = textwrap.dedent(source) + return parse_str(source), source + + +def parse_str(src): + """Returns the AST of given piece of code.""" + return gast.parse(src) + + +def parse_expression(src): + """Returns the AST of given identifier. + + Args: + src: A piece of code that represents a single Python expression + Returns: + A gast.AST object. + Raises: + ValueError: if src does not consist of a single Expression. + """ + node = parse_str(src) + assert isinstance(node, gast.Module) + if len(node.body) != 1 and not isinstance(node.body[0], gast.Expr): + raise ValueError( + 'Expected a single expression, found instead %s' % node.body) + return node.body[0].value diff --git a/tensorflow/contrib/py2tf/pyct/parser_test.py b/tensorflow/contrib/autograph/pyct/parser_test.py similarity index 80% rename from tensorflow/contrib/py2tf/pyct/parser_test.py rename to tensorflow/contrib/autograph/pyct/parser_test.py index f35dfa04c70dc191078248c32f9a04d28133129a..007a4c6fb0393b7235808478d55b3ffa469f85d0 100644 --- a/tensorflow/contrib/py2tf/pyct/parser_test.py +++ b/tensorflow/contrib/autograph/pyct/parser_test.py @@ -20,28 +20,33 @@ from __future__ import print_function import textwrap -from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test -def f(x): - return x + 1 - - class ParserTest(test.TestCase): def test_parse_entity(self): + + def f(x): + return x + 1 + mod, _ = parser.parse_entity(f) self.assertEqual('f', mod.body[0].name) def test_parse_str(self): mod = parser.parse_str( textwrap.dedent(""" - def f(x): - return x + 1 + def f(x): + return x + 1 """)) self.assertEqual('f', mod.body[0].name) + def test_parse_expression(self): + node = parser.parse_expression('a.b') + self.assertEqual('a', node.value.id) + self.assertEqual('b', node.attr) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer.py b/tensorflow/contrib/autograph/pyct/pretty_printer.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/pretty_printer.py rename to tensorflow/contrib/autograph/pyct/pretty_printer.py diff --git a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py similarity index 96% rename from tensorflow/contrib/py2tf/pyct/pretty_printer_test.py rename to tensorflow/contrib/autograph/pyct/pretty_printer_test.py index 81e3f47b80b6cb3bb7ba9f4a1787d03df4151a99..0cb48f35760b7b2655eb5cf73017b70e28dae219 100644 --- a/tensorflow/contrib/py2tf/pyct/pretty_printer_test.py +++ b/tensorflow/contrib/autograph/pyct/pretty_printer_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import ast -from tensorflow.contrib.py2tf.pyct import pretty_printer +from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py new file mode 100644 index 0000000000000000000000000000000000000000..583cf7ecd7bce31c55de58361ab5295abb5d6707 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -0,0 +1,228 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for manipulating qualified names. + +A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite +(e.g. 'foo.bar') syntactic symbols. + +This is *not* related to the __qualname__ attribute used by inspect, which +refers to scopes. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import gast + +from tensorflow.contrib.autograph.pyct import anno + + +class Symbol(collections.namedtuple('Symbol', ['name'])): + """Represents a Python symbol.""" + + +class StringLiteral(collections.namedtuple('StringLiteral', ['value'])): + """Represents a Python string literal.""" + + def __str__(self): + return '\'%s\'' % self.value + + def __repr__(self): + return str(self) + + +class NumberLiteral(collections.namedtuple('NumberLiteral', ['value'])): + """Represents a Python numeric literal.""" + + def __str__(self): + return '%s' % self.value + + def __repr__(self): + return str(self) + + +# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans. +class QN(object): + """Represents a qualified name.""" + + def __init__(self, base, attr=None, subscript=None): + if attr is not None and subscript is not None: + raise ValueError('A QN can only be either an attr or a subscript, not ' + 'both: attr={}, subscript={}.'.format(attr, subscript)) + self._has_attr = False + self._has_subscript = False + + if attr is not None: + if not isinstance(base, QN): + raise ValueError( + 'for attribute QNs, base must be a QN; got instead "%s"' % base) + if not isinstance(attr, str): + raise ValueError('attr may only be a string; got instead "%s"' % attr) + self._parent = base + # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now. + self.qn = (base, attr) + self._has_attr = True + + elif subscript is not None: + if not isinstance(base, QN): + raise ValueError('For subscript QNs, base must be a QN.') + self._parent = base + self.qn = (base, subscript) + self._has_subscript = True + + else: + if not isinstance(base, (str, StringLiteral, NumberLiteral)): + # TODO(mdan): Require Symbol instead of string. + raise ValueError( + 'For simple QNs, base must be a string or a Literal object.') + assert '.' not in base and '[' not in base and ']' not in base + self._parent = None + self.qn = (base,) + + def is_symbol(self): + return isinstance(self.qn[0], str) + + def is_composite(self): + return len(self.qn) > 1 + + def has_subscript(self): + return self._has_subscript + + def has_attr(self): + return self._has_attr + + @property + def parent(self): + if self._parent is None: + raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0]) + return self._parent + + @property + def support_set(self): + """Returns the set of simple symbols that this QN relies on. + + This would be the smallest set of symbols necessary for the QN to + statically resolve (assuming properties and index ranges are verified + at runtime). + + Examples: + 'a.b' has only one support symbol, 'a' + 'a[i]' has two roots, 'a' and 'i' + """ + # TODO(mdan): This might be the set of Name nodes in the AST. Track those? + roots = set() + if self.has_attr(): + roots.update(self.parent.support_set) + elif self.has_subscript(): + roots.update(self.parent.support_set) + roots.update(self.qn[1].support_set) + else: + roots.add(self) + return roots + + def __hash__(self): + return hash(self.qn + (self._has_attr, self._has_subscript)) + + def __eq__(self, other): + return (isinstance(other, QN) and self.qn == other.qn and + self.has_subscript() == other.has_subscript() and + self.has_attr() == other.has_attr()) + + def __str__(self): + if self.has_subscript(): + return str(self.qn[0]) + '[' + str(self.qn[1]) + ']' + if self.has_attr(): + return '.'.join(map(str, self.qn)) + else: + return str(self.qn[0]) + + def __repr__(self): + return str(self) + + def ssf(self): + """Simple symbol form.""" + ssfs = [n.ssf() if isinstance(n, QN) else n for n in self.qn] + ssf_string = '' + for i in range(0, len(self.qn) - 1): + if self.has_subscript(): + delimiter = '_sub_' + else: + delimiter = '_' + ssf_string += ssfs[i] + delimiter + return ssf_string + ssfs[-1] + + def ast(self): + # The caller must adjust the context appropriately. + if self.has_subscript(): + return gast.Subscript(self.parent.ast(), gast.Index(self.qn[-1].ast()), + None) + if self.has_attr(): + return gast.Attribute(self.parent.ast(), self.qn[-1], None) + + base = self.qn[0] + if isinstance(base, str): + return gast.Name(base, None, None) + elif isinstance(base, StringLiteral): + return gast.Str(base.value) + elif isinstance(base, NumberLiteral): + return gast.Num(base.value) + else: + assert False, ('the constructor should prevent types other than ' + 'str, StringLiteral and NumberLiteral') + + +class QnResolver(gast.NodeTransformer): + """Annotates nodes with QN information. + + Note: Not using NodeAnnos to avoid circular dependencies. + """ + + def visit_Name(self, node): + node = self.generic_visit(node) + anno.setanno(node, anno.Basic.QN, QN(node.id)) + return node + + def visit_Attribute(self, node): + node = self.generic_visit(node) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr)) + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + s = node.slice + if not isinstance(s, gast.Index): + # TODO(mdan): Support range and multi-dimensional indices. + # Continuing silently because some demos use these. + return node + if isinstance(s.value, gast.Num): + subscript = QN(NumberLiteral(s.value.n)) + elif isinstance(s.value, gast.Str): + subscript = QN(StringLiteral(s.value.s)) + else: + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + if anno.hasanno(node.value, anno.Basic.QN): + anno.setanno(node, anno.Basic.QN, + QN(anno.getanno(node.value, anno.Basic.QN), + subscript=subscript)) + return node + + +def resolve(node): + return QnResolver().visit(node) diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py new file mode 100644 index 0000000000000000000000000000000000000000..264afd508cdb847315c486806b531dc1483ef622 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py @@ -0,0 +1,246 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for qual_names module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.qual_names import resolve +from tensorflow.python.platform import test + + +class QNTest(test.TestCase): + + def test_basic(self): + a = QN('a') + self.assertEqual(a.qn, ('a',)) + self.assertEqual(str(a), 'a') + self.assertEqual(a.ssf(), 'a') + self.assertEqual(a.ast().id, 'a') + self.assertFalse(a.is_composite()) + with self.assertRaises(ValueError): + _ = a.parent + + a_b = QN(a, attr='b') + self.assertEqual(a_b.qn, (a, 'b')) + self.assertEqual(str(a_b), 'a.b') + self.assertEqual(a_b.ssf(), 'a_b') + self.assertEqual(a_b.ast().value.id, 'a') + self.assertEqual(a_b.ast().attr, 'b') + self.assertTrue(a_b.is_composite()) + self.assertEqual(a_b.parent.qn, ('a',)) + + def test_subscripts(self): + a = QN('a') + b = QN('b') + a_sub_b = QN(a, subscript=b) + self.assertEqual(a_sub_b.qn, (a, b)) + self.assertEqual(str(a_sub_b), 'a[b]') + self.assertEqual(a_sub_b.ssf(), 'a_sub_b') + self.assertEqual(a_sub_b.ast().value.id, 'a') + self.assertEqual(a_sub_b.ast().slice.value.id, 'b') + self.assertTrue(a_sub_b.is_composite()) + self.assertTrue(a_sub_b.has_subscript()) + self.assertEqual(a_sub_b.parent.qn, ('a',)) + + c = QN('c') + b_sub_c = QN(b, subscript=c) + a_sub_b_sub_c = QN(a, subscript=b_sub_c) + self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c)) + self.assertTrue(a_sub_b.is_composite()) + self.assertTrue(a_sub_b_sub_c.is_composite()) + self.assertTrue(a_sub_b.has_subscript()) + self.assertTrue(a_sub_b_sub_c.has_subscript()) + self.assertEqual(b_sub_c.qn, (b, c)) + self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') + self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c') + self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a') + self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b') + self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c') + self.assertEqual(b_sub_c.ast().slice.value.id, 'c') + self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',)) + with self.assertRaises(ValueError): + QN('a', 'b') + + def test_equality(self): + a = QN('a') + a2 = QN('a') + a_b = QN(a, attr='b') + self.assertEqual(a2.qn, ('a',)) + with self.assertRaises(ValueError): + _ = a.parent + + a_b2 = QN(a, attr='b') + self.assertEqual(a_b2.qn, (a, 'b')) + self.assertEqual(a_b2.parent.qn, ('a',)) + + self.assertTrue(a2 == a) + self.assertFalse(a2 is a) + + self.assertTrue(a_b.parent == a) + self.assertTrue(a_b2.parent == a) + + self.assertTrue(a_b2 == a_b) + self.assertFalse(a_b2 is a_b) + self.assertFalse(a_b2 == a) + a_sub_b = QN(a, subscript='b') + a_sub_b2 = QN(a, subscript='b') + self.assertTrue(a_sub_b == a_sub_b2) + self.assertFalse(a_sub_b == a_b) + + def test_nested_attrs_subscripts(self): + a = QN('a') + b = QN('b') + c = QN('c') + b_sub_c = QN(b, subscript=c) + a_sub_b_sub_c = QN(a, subscript=b_sub_c) + + b_dot_c = QN(b, attr='c') + a_sub__b_dot_c = QN(a, subscript=b_dot_c) + + a_sub_b = QN(a, subscript=b) + a_sub_b__dot_c = QN(a_sub_b, attr='c') + + a_dot_b = QN(a, attr='b') + a_dot_b_sub_c = QN(a_dot_b, subscript=c) + + self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]') + self.assertEqual(str(a_sub__b_dot_c), 'a[b.c]') + self.assertEqual(str(a_sub_b__dot_c), 'a[b].c') + self.assertEqual(str(a_dot_b_sub_c), 'a.b[c]') + + self.assertNotEqual(a_sub_b_sub_c, a_sub__b_dot_c) + self.assertNotEqual(a_sub_b_sub_c, a_sub_b__dot_c) + self.assertNotEqual(a_sub_b_sub_c, a_dot_b_sub_c) + + self.assertNotEqual(a_sub__b_dot_c, a_sub_b__dot_c) + self.assertNotEqual(a_sub__b_dot_c, a_dot_b_sub_c) + + self.assertNotEqual(a_sub_b__dot_c, a_dot_b_sub_c) + + def test_hashable(self): + d = {QN('a'): 'a', QN('b'): 'b'} + self.assertEqual(d[QN('a')], 'a') + self.assertEqual(d[QN('b')], 'b') + self.assertTrue(QN('c') not in d) + + def test_literals(self): + a = QN('a') + a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b'))) + a_sub_b = QN(a, subscript=QN('b')) + + self.assertNotEqual(a_sub_str_b, a_sub_b) + self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b)) + + a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3))) + self.assertEqual(a_sub_three.ast().slice.value.n, 3) + + def test_support_set(self): + a = QN('a') + b = QN('b') + c = QN('c') + a_sub_b = QN(a, subscript=b) + a_dot_b = QN(a, attr='b') + a_dot_b_dot_c = QN(a_dot_b, attr='c') + a_dot_b_sub_c = QN(a_dot_b, subscript=c) + + self.assertSetEqual(a.support_set, set((a,))) + self.assertSetEqual(a_sub_b.support_set, set((a, b))) + self.assertSetEqual(a_dot_b.support_set, set((a,))) + self.assertSetEqual(a_dot_b_dot_c.support_set, set((a,))) + self.assertSetEqual(a_dot_b_sub_c.support_set, set((a, c))) + + +class QNResolverTest(test.TestCase): + + def assertQNStringIs(self, node, qn_str): + self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) + + def test_resolve(self): + samples = """ + a + a.b + (c, d.e) + [f, (g.h.i)] + j(k, l) + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'a') + self.assertQNStringIs(nodes[1], 'a.b') + self.assertQNStringIs(nodes[2].elts[0], 'c') + self.assertQNStringIs(nodes[2].elts[1], 'd.e') + self.assertQNStringIs(nodes[3].elts[0], 'f') + self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') + self.assertQNStringIs(nodes[4].func, 'j') + self.assertQNStringIs(nodes[4].args[0], 'k') + self.assertQNStringIs(nodes[4].args[1], 'l') + + def test_subscript_resolve(self): + samples = """ + x[i] + x[i.b] + a.b[c] + a.b[x.y] + a[z[c]] + a[b[c[d]]] + a[b].c + a.b.c[d].e.f + a.b[c[d]].e.f + a.b[c[d.e.f].g].h + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + + self.assertQNStringIs(nodes[0], 'x[i]') + self.assertQNStringIs(nodes[1], 'x[i.b]') + self.assertQNStringIs(nodes[2], 'a.b[c]') + self.assertQNStringIs(nodes[3], 'a.b[x.y]') + self.assertQNStringIs(nodes[4], 'a[z[c]]') + self.assertQNStringIs(nodes[5], 'a[b[c[d]]]') + self.assertQNStringIs(nodes[6], 'a[b].c') + self.assertQNStringIs(nodes[7], 'a.b.c[d].e.f') + self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f') + self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h') + + def test_function_calls(self): + samples = """ + a.b + a.b() + a().b + z[i] + z[i]() + z()[i] + """ + nodes = resolve(parser.parse_str(textwrap.dedent(samples))) + nodes = tuple(n.value for n in nodes.body) + self.assertQNStringIs(nodes[0], 'a.b') + self.assertQNStringIs(nodes[1].func, 'a.b') + self.assertQNStringIs(nodes[2].value.func, 'a') + self.assertQNStringIs(nodes[3], 'z[i]') + self.assertQNStringIs(nodes[4].func, 'z[i]') + self.assertQNStringIs(nodes[5].value.func, 'z') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD similarity index 80% rename from tensorflow/contrib/py2tf/pyct/static_analysis/BUILD rename to tensorflow/contrib/autograph/pyct/static_analysis/BUILD index fbfce18c60cca4b105e7de3c3ea7b9c3438f6b2a..83f3bafc4217649db6499566d548c1657428ad0b 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -25,7 +25,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "@gast_archive//:gast", ], ) @@ -34,9 +34,10 @@ py_test( name = "activity_test", srcs = ["activity_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", "@gast_archive//:gast", ], @@ -46,9 +47,10 @@ py_test( name = "live_values_test", srcs = ["live_values_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], ) @@ -59,7 +61,8 @@ py_test( srcs_version = "PY2AND3", deps = [ ":static_analysis", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py similarity index 100% rename from tensorflow/contrib/py2tf/pyct/static_analysis/__init__.py rename to tensorflow/contrib/autograph/pyct/static_analysis/__init__.py diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py similarity index 66% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 1c93e1603113d48176af7a97a0f37321e6f67586..2c14c2c8c23810c64446eb9e7ffc5402ce9a2298 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -22,9 +22,10 @@ import copy import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). @@ -70,13 +71,33 @@ class Scope(object): tuple(self.modified)) def copy_from(self, other): + """Recursively copies the contents of this scope from another scope.""" + if (self.parent is None) != (other.parent is None): + raise ValueError('cannot copy scopes of different structures') + if other.parent is not None: + self.parent.copy_from(other.parent) + self.isolated = other.isolated self.modified = copy.copy(other.modified) self.created = copy.copy(other.created) self.used = copy.copy(other.used) self.params = copy.copy(other.params) self.returned = copy.copy(other.returned) + @classmethod + def copy_of(cls, other): + if other.parent is not None: + parent = cls.copy_of(other.parent) + else: + parent = None + new_copy = cls(parent) + new_copy.copy_from(other) + return new_copy + def merge_from(self, other): + if (self.parent is None) != (other.parent is None): + raise ValueError('cannot merge scopes of different structures') + if other.parent is not None: + self.parent.merge_from(other.parent) self.modified |= other.modified self.created |= other.created self.used |= other.used @@ -112,18 +133,18 @@ class Scope(object): def mark_param(self, name): self.params.add(name) - def mark_creation(self, name): + def mark_creation(self, name, writes_create_symbol=False): if name.is_composite(): parent = name.parent if self.has(parent): - # This is considered mutation of the parent, not creation. - # TODO(mdan): Is that really so? - return + if not writes_create_symbol: + return else: raise ValueError('Unknown symbol "%s".' % parent) self.created.add(name) def mark_write(self, name): + """Marks the given symbol as modified in the current scope.""" self.modified.add(name) if self.isolated: self.mark_creation(name) @@ -141,19 +162,45 @@ class Scope(object): self.parent.mark_returned(name) -class ActivityAnalizer(transformer.Base): +class ActivityAnalyzer(transformer.Base): """Annotates nodes with local scope information. See Scope.""" def __init__(self, context, parent_scope): - super(ActivityAnalizer, self).__init__(context) + super(ActivityAnalyzer, self).__init__(context) self.scope = Scope(parent_scope) self._in_return_statement = False - def _track_symbol(self, node): + @property + def _in_constructor(self): + innermost = self.enclosing_entities[-1] + if len(self.enclosing_entities) > 1: + parent = self.enclosing_entities[-2] + return isinstance(parent, gast.ClassDef) and innermost.name == '__init__' + return False + + def _node_sets_self_attribute(self, node): + if anno.hasanno(node, anno.Basic.QN): + qn = anno.getanno(node, anno.Basic.QN) + # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'. + if qn.has_attr and qn.parent.qn == ('self',): + return True + + def _track_symbol(self, + node, + composite_writes_alter_parent=False, + writes_create_symbol=False): + # A QN may be missing when we have an attribute (or subscript) on a function + # call. Example: a().b + if not anno.hasanno(node, anno.Basic.QN): + return qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Store): self.scope.mark_write(qn) + if qn.is_composite and composite_writes_alter_parent: + self.scope.mark_write(qn.parent) + if writes_create_symbol: + self.scope.mark_creation(qn, writes_create_symbol=True) elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): @@ -182,7 +229,18 @@ class ActivityAnalizer(transformer.Base): def visit_Attribute(self, node): self.generic_visit(node) - self._track_symbol(node) + if self._in_constructor and self._node_sets_self_attribute(node): + self._track_symbol( + node, composite_writes_alter_parent=True, writes_create_symbol=True) + else: + self._track_symbol(node) + return node + + def visit_Subscript(self, node): + self.generic_visit(node) + # Subscript writes (e.g. a[b] = "value") are considered to modify + # both the element itself (a[b]) and its parent (a). + self._track_symbol(node, composite_writes_alter_parent=True) return node def visit_Print(self, node): @@ -224,21 +282,46 @@ class ActivityAnalizer(transformer.Base): # modifies the parent state causing the other child blocks to be # processed incorrectly. So we need to checkpoint the parent scope so that # each child sees the same context. - before_parent = Scope(None) - before_parent.copy_from(self.scope) + before_parent = Scope.copy_of(self.scope) after_children = [] for child, scope_name in children: self.scope.copy_from(before_parent) parent = self._process_block_node(parent, child, scope_name) - after_child = Scope(None) - after_child.copy_from(self.scope) + after_child = Scope.copy_of(self.scope) after_children.append(after_child) for after_child in after_children: self.scope.merge_from(after_child) return parent + def visit_FunctionDef(self, node): + if self.scope: + qn = QN(node.name) + self.scope.mark_write(qn) + current_scope = self.scope + body_scope = Scope(current_scope, isolated=True) + self.scope = body_scope + self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope) + self.scope = current_scope + return node + + def visit_With(self, node): + current_scope = self.scope + with_scope = Scope(current_scope, isolated=False) + self.scope = with_scope + self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, with_scope) + self.scope = current_scope + return node + def visit_If(self, node): + current_scope = self.scope + cond_scope = Scope(current_scope, isolated=False) + self.scope = cond_scope self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) + self.scope = current_scope + node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) @@ -253,7 +336,13 @@ class ActivityAnalizer(transformer.Base): return node def visit_While(self, node): + current_scope = self.scope + cond_scope = Scope(current_scope, isolated=False) + self.scope = cond_scope self.visit(node.test) + anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope) + self.scope = current_scope + node = self._process_parallel_blocks(node, ((node.body, NodeAnno.BODY_SCOPE), (node.orelse, NodeAnno.ORELSE_SCOPE))) @@ -267,4 +356,4 @@ class ActivityAnalizer(transformer.Base): def resolve(node, context, parent_scope=None): - return ActivityAnalizer(context, parent_scope).visit(node) + return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py similarity index 51% rename from tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index e1eb954a5efef4d6a00ac492e7c85394d54e28c9..ef79a295bfa3940705d2f341edd4eda74d7d7068 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -20,13 +20,13 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.qual_names import QN -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.qual_names import QN +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno from tensorflow.python.platform import test @@ -45,7 +45,7 @@ class ScopeTest(test.TestCase): scope.mark_read(QN('bar')) self.assertFalse(scope.has(QN('bar'))) - def test_copy(self): + def test_copy_from(self): scope = activity.Scope(None) scope.mark_write(QN('foo')) @@ -65,6 +65,17 @@ class ScopeTest(test.TestCase): self.assertTrue(QN('bar') in scope.created) self.assertFalse(QN('bar') in other.created) + def test_copy_of(self): + scope = activity.Scope(None) + scope.mark_read(QN('foo')) + + self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used) + + child_scope = activity.Scope(scope) + child_scope.mark_read(QN('bar')) + + self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used) + def test_nesting(self): scope = activity.Scope(None) scope.mark_write(QN('foo')) @@ -97,7 +108,7 @@ class ScopeTest(test.TestCase): self.assertFalse(QN('a') in child.referenced) -class ActivityAnalizerTest(test.TestCase): +class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) @@ -108,6 +119,7 @@ class ActivityAnalizerTest(test.TestCase): namespace={}, arg_values=None, arg_types=None, + owner_type=None, recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) @@ -132,10 +144,21 @@ class ActivityAnalizerTest(test.TestCase): anno.getanno(node.body[0].body[2].value, NodeAnno.IS_LOCAL)) # b in return b - def assertScopeIs(self, scope, used, modified, created): - self.assertItemsEqual(used, tuple(str(s) for s in scope.used)) - self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified)) - self.assertItemsEqual(created, tuple(str(s) for s in scope.created)) + def assertSymbolSetsAre(self, expected, actual, name): + expected = set(expected) + actual = set(str(s) for s in actual) + self.assertSetEqual( + expected, actual, 'for symbol set: %s\n' + ' Expected: %s\n' + ' Got: %s\n' + ' Missing: %s\n' + ' Extra: %s\n' % (name.upper(), expected, actual, + expected - actual, actual - expected)) + + def assertScopeIsRmc(self, scope, used, modified, created): + self.assertSymbolSetsAre(used, scope.used, 'read') + self.assertSymbolSetsAre(modified, scope.modified, 'modified') + self.assertSymbolSetsAre(created, scope.created, 'created') def test_print_statement(self): @@ -158,9 +181,9 @@ class ActivityAnalizerTest(test.TestCase): print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE) # We basically need to detect which variables are captured by the call # arguments. - self.assertScopeIs(print_args_scope, ('a', 'b'), (), ()) + self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ()) - def test_call(self): + def test_call_args(self): def test_fn(a): b = 0 @@ -172,9 +195,60 @@ class ActivityAnalizerTest(test.TestCase): call_node = node.body[0].body[2].value # We basically need to detect which variables are captured by the call # arguments. - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) + def test_call_args_attributes(self): + + def foo(*_): + pass + + def test_fn(a): + a.c = 0 + foo(a.b, a.c) + return a.d + + node = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[1].value + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), + ('a', 'a.b', 'a.c'), + (), + (), + ) + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, + ('a', 'a.b', 'a.c', 'a.d', 'foo'), + ('a.c',), + ('a',), + ) + + def test_call_args_subscripts(self): + + def foo(*_): + pass + + def test_fn(a): + b = 1 + c = 2 + foo(a[0], a[b]) + return a[c] + + node = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[2].value + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), + ('a', 'a[0]', 'a[b]', 'b'), + (), + (), + ) + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, + ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'), + ('b', 'c'), + ('a', 'b', 'c'), + ) + def test_while(self): def test_fn(a): @@ -186,12 +260,14 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) while_node = node.body[0].body[1] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), ('b', 'c'), ('a', 'b', 'c')) + self.assertScopeIsRmc( + anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ()) def test_for(self): @@ -204,9 +280,9 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) for_node = node.body[0].body[1] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'), ('b', 'c', '_'), ('a', 'b', 'c', '_')) @@ -225,46 +301,161 @@ class ActivityAnalizerTest(test.TestCase): node = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), ('y', 'z')) # TODO(mdan): Double check: is it ok to not mark a local symbol as not read? - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'), ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'), ('x', 'y', 'u'), ('y', 'u')) - self.assertScopeIs( + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - def test_call_with_composite_names(self): - - def foo(*_): - pass + def test_if_attributes(self): def test_fn(a): - foo(a.b, a.c) if a > 0: - a.b = 2 + a.b = -a.c + d = 2 * a + else: + a.b = a.c + d = 1 + return d + + node = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), + ('a', 'a.c'), + ('a.b', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'a.c'), + ('a.b', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, + ('a', 'a.c', 'd'), + ('a.b', 'd'), + ('a', 'd'), + ) + + def test_if_subscripts(self): + + def test_fn(a, b, c, e): + if a > 0: + a[b] = -a[c] + d = 2 * a else: - d = 2 - d.e = a.c - f = d.e + 1 - a.c = f + a[0] = e + d = 1 + return d node = self._parse_and_analyze(test_fn) - call_node = node.body[0].body[0].value - self.assertScopeIs( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (), - ()) - if_node = node.body[0].body[1] - self.assertScopeIs( - anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ()) - self.assertScopeIs( + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), + ('a', 'b', 'c', 'a[c]'), + ('a', 'a[b]', 'd'), + ('d',), + ) + # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"? + self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), - ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f')) + ('a', 'e'), + ('a', 'a[0]', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, + ('a', 'b', 'c', 'd', 'e', 'a[c]'), + ('a', 'd', 'a[b]', 'a[0]'), + ('a', 'b', 'c', 'd', 'e'), + ) + + def test_nested_if(self): + + def test_fn(b): + if b > 0: + if b < 5: + a = b + else: + a = b * b + return a + + node = self._parse_and_analyze(test_fn) + inner_if_node = node.body[0].body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), + ('a',)) + self.assertScopeIsRmc( + anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',), + ('a',)) + + def test_nested_function(self): + + def test_fn(a): + + def f(x): + y = x * x + return y + + b = a + for i in a: + c = b + b -= f(i) + return b, c + + node = self._parse_and_analyze(test_fn) + fn_def_node = node.body[0].body[0] + + self.assertScopeIsRmc( + anno.getanno(fn_def_node, + NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), + ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) + self.assertScopeIsRmc( + anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( + 'x', + 'y', + )) + + def test_constructor_attributes(self): + + class TestClass(object): + + def __init__(self, a): + self.b = a + self.b.c = 1 + + node = self._parse_and_analyze(TestClass) + init_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(init_node, NodeAnno.BODY_SCOPE), + ('self', 'a', 'self.b'), + ('self', 'self.b', 'self.b.c'), + ('self', 'a', 'self.b'), + ) + + def test_aug_assign_subscripts(self): + + def test_fn(a): + a[0] += 1 + + node = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a',), + ('a', 'a[0]'), + ('a',), + ) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py similarity index 67% rename from tensorflow/contrib/py2tf/pyct/static_analysis/annos.py rename to tensorflow/contrib/autograph/pyct/static_analysis/annos.py index 2d8e49442364fdd4a4752c8a83a5f3b76117fe57..b929b35b79200b0968c9c4f26b10cda28763773a 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/annos.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Annotations used by the static analizer.""" +"""Annotations used by the static analyzer.""" from __future__ import absolute_import from __future__ import division @@ -28,23 +28,32 @@ class NoValue(Enum): class NodeAnno(NoValue): - """Additionnal annotations used by the static analyzer. + """Additional annotations used by the static analyzer. These are in addition to the basic annotations declared in anno.py. """ # Symbols - - IS_LOCAL = 'Symbol is local to the function scope being analized.' - IS_PARAM = 'Symbol is a parameter to the function being analized.' + # These flags are boolean. + IS_LOCAL = 'Symbol is local to the function scope being analyzed.' + IS_PARAM = 'Symbol is a parameter to the function being analyzed.' IS_MODIFIED_SINCE_ENTRY = ( 'Symbol has been explicitly replaced in the current function scope.') # Scopes + # Scopes are represented by objects of type activity.Scope. ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' BODY_SCOPE = ( 'The scope for the main body of a statement (True branch for if ' 'statements, main body for loops).') ORELSE_SCOPE = ( 'The scope for the orelse body of a statement (False branch for if ' 'statements, orelse body for loops).') + + # Type and Value annotations + # Type annotations are represented by objects of type type_info.Type. + STATIC_INFO = ( + 'The type or value information that should be asserted about the entity ' + 'referenced by the symbol holding this annotation, irrespective of the ' + 'execution context.') diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py similarity index 85% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 9c0a9a9e74eccb3d22840032e8f0c2b81e051e7e..53ae15459097baff918432a493edd7360ebf209d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -25,9 +25,9 @@ from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno class LiveValueResolver(transformer.Base): @@ -55,11 +55,19 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - # TODO(mdan): Could live values have FQNs? i.e. 'a'.join() elif node.id in self.context.namespace: obj = self.context.namespace[node.id] anno.setanno(node, 'live_val', obj) - anno.setanno(node, 'fqn', (obj.__name__,)) + if hasattr(obj, '__name__'): + anno.setanno(node, 'fqn', (obj.__name__,)) + elif hasattr(obj, '__class__'): + obj_class = obj.__class__ + anno.setanno(node, 'fqn', + (obj_class.__module__, obj_class.__name__)) + else: + # If the symbol value is for example a primitive, then it will not + # have a name. + pass else: pass # TODO(mdan): Should we raise an error here? @@ -86,6 +94,7 @@ class LiveValueResolver(transformer.Base): if not hasattr(parent_object, node.attr): raise AttributeError('%s has no attribute %s' % (parent_object, node.attr)) + anno.setanno(node, 'parent_type', type(parent_object)) anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) # TODO(mdan): Investigate the role built-in annotations can play here. @@ -96,6 +105,7 @@ class LiveValueResolver(transformer.Base): # This would not hold for dynamic members like function attributes. # For the dynamic case, we simply leave the node without an annotation, # and let downstream consumers figure out what to do. + anno.setanno(node, 'parent_type', parent_type) anno.setanno(node, 'live_val', getattr(parent_type, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'type_fqn') + (node.attr,)) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py similarity index 75% rename from tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 9f64689401e3594a77fbdd7b6f02880bd6e90492..69e428bde109ed43c3cdda1a94970a832dc47852 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -18,13 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -46,6 +48,7 @@ class LiveValuesResolverTest(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, + owner_type=None, recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) @@ -56,13 +59,30 @@ class LiveValuesResolverTest(test.TestCase): def test_literals(self): + a = None + def test_fn(): - return Foo # pylint: disable=undefined-variable + return a - node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'}) + node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'}) retval_node = node.body[0].body[0].value self.assertEquals('bar', anno.getanno(retval_node, 'live_val')) + def test_primitive_values(self): + + a = None + + def test_fn(): + return a + + node = self._parse_and_analyze(test_fn, {'a': True}) + retval_node = node.body[0].body[0].value + if six.PY2: + self.assertEqual( + anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool')) + else: + self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool')) + def test_namespace(self): def foo(): @@ -102,6 +122,7 @@ class LiveValuesResolverTest(test.TestCase): arg_types={'self': (TestClass.__name__, TestClass)}) func_node = node.body[0].body[0].value.func self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type')) self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py similarity index 52% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index 8203bda0f9a792a5b24b9abb25d8f39b61625748..2f553e1e23dfb7367fc9b6123222e685d0490780 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -14,22 +14,45 @@ # ============================================================================== """Type resolution. +This analyzer uses known live values to further infer object types. This +may include for instance constructed objects and object member functions. + +In addition, the analyzer will also process annotations for TF (staged) type +annotations. + Requires annotations generated by LiveValuesResolver. """ +# TODO(mdan): This would be more robust with a CFG. +# Situations with multiple reaching modifications (e.g. modified inside and +# outside a control flow statement) should be more robustly detected and +# analyzed. + +# TODO(mdan): Look into using Python AST's type annotation fields instead. +# It would be desirable to use that mechanism if we can. +# Some caveats to consider: We may need to annotate other nodes like +# Attribute. It may also not be feasible for us to faithfully to replicate +# PY3's type annotations where it isn't available. It would also require us +# to design rigorous type definitions that can accommodate Python types +# as well as TensorFLow dtypes and shapes. + + from __future__ import absolute_import from __future__ import division from __future__ import print_function import gast -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect +# TODO(mdan): Remove the duplication between this and activity.py. +# In particular, the symbol definitions we track here could as well be tracked +# there because they follow the same rules for visibility. class Scope(object): - """Encloses symbol value references. + """Tracks symbol value references. Attributes: values: A dict mapping string to gast.Node, containing the value that was @@ -79,20 +102,16 @@ class TypeInfoResolver(transformer.Base): def __init__(self, context): super(TypeInfoResolver, self).__init__(context) self.scope = Scope(None) - self.function_level = 0 def visit_FunctionDef(self, node): self.scope = Scope(self.scope) - self.function_level += 1 - self.generic_visit(node) - self.function_level -= 1 + node = self.generic_visit(node) self.scope = self.scope.parent return node def _visit_block(self, block): self.scope = Scope(self.scope) - for i, n in enumerate(block): - block[i] = self.generic_visit(n) + block = self.visit_block(block) self.scope = self.scope.parent return block @@ -117,7 +136,7 @@ class TypeInfoResolver(transformer.Base): def _process_function_arg(self, arg_name): str_name = str(arg_name) - if self.function_level == 1 and str_name in self.context.arg_types: + if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. type_holder = arg_name.ast() @@ -138,14 +157,18 @@ class TypeInfoResolver(transformer.Base): elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): # E.g. if we had # a = b - # then for future references to `a` we should have traced_source = `b` - traced_source = self.scope.getval(qn) - if anno.hasanno(traced_source, 'type'): - anno.setanno(node, 'type', anno.getanno(traced_source, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(traced_source, 'type_fqn')) + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + if anno.hasanno(definition, 'type'): + anno.setanno(node, 'type', anno.getanno(definition, 'type')) + anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) + if anno.hasanno(definition, 'element_type'): + anno.setanno(node, 'element_type', + anno.getanno(definition, 'element_type')) return node def _process_variable_assignment(self, source, targets): + # Special case: constructors. if isinstance(source, gast.Call): func = source.func if anno.hasanno(func, 'live_val'): @@ -158,16 +181,26 @@ class TypeInfoResolver(transformer.Base): # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - for t in targets: - if isinstance(t, gast.Tuple): - for i, e in enumerate(t.elts): - self.scope.setval( - anno.getanno(e, anno.Basic.QN), - gast.Subscript(source, gast.Index(i), ctx=gast.Store())) - elif isinstance(t, (gast.Name, gast.Attribute)): - self.scope.setval(anno.getanno(t, anno.Basic.QN), source) + # Multiple targets mean multiple assignment. + for target in targets: + # Tuple target means unpacking. + if isinstance(target, gast.Tuple): + for i, target_item in enumerate(target.elts): + # Two cases here: + # 1. Static unpacking, e.g. a, b = c, d + # 2. Dynamic unpacking, e.g. a, b = c + # The former case is optimized away. + if isinstance(source, (gast.Tuple, gast.List)): + source_item = source.elts[i] + else: + source_item = gast.Subscript(source, gast.Index(i), ctx=None) + self._process_variable_assignment(source_item, (target_item,)) + elif isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, source) else: - raise ValueError('Dont know how to handle assignment to %s' % t) + raise ValueError( + 'assignment target has unknown type: %s' % target_item) def visit_With(self, node): for wi in node.items: @@ -181,6 +214,41 @@ class TypeInfoResolver(transformer.Base): self._process_variable_assignment(node.value, node.targets) return node + def visit_Call(self, node): + if anno.hasanno(node.func, 'live_val'): + # Symbols targeted by the "set_type" marker function are assigned the data + # type that it specified. + if (anno.getanno(node.func, 'live_val') is + self.context.type_annotation_func): + + if len(node.args) != 2: + raise ValueError('"%s" must have exactly two parameters' + % self.context.type_annotation_func) + target_arg, type_arg = node.args + if not anno.hasanno(target_arg, anno.Basic.QN): + raise ValueError('the first argument of "%s" must by a symbol' + % self.context.type_annotation_func) + if isinstance(type_arg, gast.Str): + element_type = type_arg.s + elif isinstance(type_arg, gast.Num): + element_type = type_arg.n + else: + if not anno.hasanno(type_arg, 'live_val'): + raise ValueError( + 'the second argument of "%s" must be statically resolvable' % + self.context.type_annotation_func) + element_type = anno.getanno(type_arg, 'live_val') + + target_symbol = anno.getanno(target_arg, anno.Basic.QN) + # Find the definition of this symbol and annotate it with the given + # data type. That in turn will cause future uses of the symbol + # to receive the same type annotation. + definition = self.scope.getval(target_symbol) + anno.setanno(node, 'element_type', element_type) + anno.setanno(definition, 'element_type', element_type) + # TODO(mdan): Should we update references between definition and here? + return self.generic_visit(node) + def resolve(node, context): return TypeInfoResolver(context).visit(node) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py similarity index 69% rename from tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py rename to tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 3659f949db9910534870d8dd9e42fd4ee8297253..46b7701624a43073fb7cc612d2678ab851513d91 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.pyct import anno -from tensorflow.contrib.py2tf.pyct import context -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names -from tensorflow.contrib.py2tf.pyct.static_analysis import activity -from tensorflow.contrib.py2tf.pyct.static_analysis import live_values -from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import activity +from tensorflow.contrib.autograph.pyct.static_analysis import live_values +from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.client import session from tensorflow.python.platform import test from tensorflow.python.training import training @@ -56,7 +57,10 @@ class ScopeTest(test.TestCase): class TypeInfoResolverTest(test.TestCase): - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + def _parse_and_analyze(self, + test_fn, + namespace, + arg_types=None): node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( namer=None, @@ -65,7 +69,9 @@ class TypeInfoResolverTest(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, - recursive=True) + owner_type=None, + recursive=True, + type_annotation_func=utils.set_element_type) node = qual_names.resolve(node) node = activity.resolve(node, ctx) node = live_values.resolve(node, ctx, {}) @@ -174,6 +180,62 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) + def test_type_annotation(self): + + class Foo(object): + pass + + def test_fn(): + f = [] + f = utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_def = node.body[0].body[0].value + self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + + def test_nested_unpacking(self): + + class Foo(object): + pass + + class Bar(object): + pass + + def test_fn(): + a, (b, c) = (Foo(), (Bar(), Foo())) + return a, b, c + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) + a, b, c = node.body[0].body[1].value.elts + self.assertEquals(Foo, anno.getanno(a, 'type')) + self.assertEquals(Bar, anno.getanno(b, 'type')) + self.assertEquals(Foo, anno.getanno(c, 'type')) + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + self.assertFalse(anno.hasanno(c, 'live_val')) + + def test_inner_scope(self): + + def test_fn(): + a = [] + utils.set_element_type(a, 1) + for _ in a: + b = [] + utils.set_element_type(b, 2) + return a, b + + node = self._parse_and_analyze(test_fn, {'utils': utils}) + a, b = node.body[0].body[2].body[2].value.elts + self.assertEquals(1, anno.getanno(a, 'element_type')) + self.assertEquals(2, anno.getanno(b, 'element_type')) + self.assertFalse(anno.hasanno(a, 'type')) + self.assertFalse(anno.hasanno(b, 'type')) + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py similarity index 54% rename from tensorflow/contrib/py2tf/pyct/templates.py rename to tensorflow/contrib/autograph/pyct/templates.py index c40e4d0fb783191705a412ab2728daabb61eda0f..baf7923fff7c786c1abd05e11fa6ffdb8c8f0912 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -26,9 +26,9 @@ import textwrap import gast -from tensorflow.contrib.py2tf.pyct import ast_util -from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names +from tensorflow.contrib.autograph.pyct import ast_util +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names class ReplaceTransformer(gast.NodeTransformer): @@ -44,8 +44,6 @@ class ReplaceTransformer(gast.NodeTransformer): self.replacements = replacements self.in_replacements = False - # TODO(mdan): Make a more detailed pass and clean up if needed. - def visit_Expr(self, node): if (isinstance(node.value, gast.Name) and node.value.id in self.replacements): @@ -53,28 +51,110 @@ class ReplaceTransformer(gast.NodeTransformer): self.generic_visit(node) return node + def visit_keyword(self, node): + if node.arg in self.replacements: + repl = self.replacements[node.arg] + if isinstance(repl, gast.keyword): + return repl + elif (isinstance(repl, (list, tuple)) and repl and + all(isinstance(r, gast.keyword) for r in repl)): + return repl + # TODO(mdan): We may allow replacing with a string as well. + # For example, if one wanted to replace foo with bar in foo=baz, then + # we could allow changing just node arg, so that we end up with bar=baz. + raise ValueError( + 'a keyword argument may only be replaced by another keyword or a ' + 'non-empty list of keywords. Found: %s' % repl) + return self.generic_visit(node) + def visit_FunctionDef(self, node): node = self.generic_visit(node) if node.name in self.replacements: repl = self.replacements[node.name] if not isinstance(repl, (gast.Name, ast.Name)): raise ValueError( - 'A function name can only be replaced by a Name node. Found: %s' % + 'a function name can only be replaced by a Name node. Found: %s' % repl) node.name = repl.id return node + def _check_has_context(self, node): + if not node.ctx: + raise ValueError('node %s is missing ctx value' % node) + + def _check_inner_children_have_context(self, node): + if isinstance(node, gast.Attribute): + self._check_inner_children_have_context(node.value) + self._check_has_context(node) + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._check_inner_children_have_context(e) + self._check_has_context(node) + elif isinstance(node, gast.Dict): + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._check_inner_children_have_context(node.value) + self._check_inner_children_have_context(node.slice) + elif isinstance(node, gast.Slice): + self._check_inner_children_have_context(node.lower) + if node.upper: + self._check_inner_children_have_context(node.upper) + if node.step: + self._check_inner_children_have_context(node.step) + elif isinstance(node, gast.Name): + self._check_has_context(node) + elif isinstance(node, (gast.Str, gast.Num)): + pass + else: + raise ValueError('unexpected node type "%s"' % node) + def _set_inner_child_context(self, node, ctx): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, ctx) node.ctx = gast.Load() + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._set_inner_child_context(e, ctx) + node.ctx = ctx elif isinstance(node, gast.Name): node.ctx = ctx + elif isinstance(node, gast.Call): + self._set_inner_child_context(node.func, ctx) + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for a in node.args: + self._check_inner_children_have_context(a) + for k in node.keywords: + self._check_inner_children_have_context(k.value) + elif isinstance(node, gast.Dict): + # We may be able to override these to Load(), but for now it's simpler + # to just assert that they're set. + for e in node.keys: + self._check_inner_children_have_context(e) + for e in node.values: + self._check_inner_children_have_context(e) + elif isinstance(node, gast.Subscript): + self._set_inner_child_context(node.value, ctx) + self._check_inner_children_have_context(node.slice) elif isinstance(node, (gast.Str, gast.Num)): pass else: raise ValueError('unexpected node type "%s"' % node) + def visit_Attribute(self, node): + node = self.generic_visit(node) + if node.attr not in self.replacements: + return node + repl = self.replacements[node.attr] + if not isinstance(repl, gast.Name): + raise ValueError( + 'An attribute can only be replaced by a Name node. Found: %s' % repl) + node.attr = repl.id + return node + def visit_Name(self, node): if node.id not in self.replacements: return node @@ -150,3 +230,17 @@ def replace(template, **replacements): if isinstance(results, list): return [qual_names.resolve(r) for r in results] return qual_names.resolve(results) + + +def replace_as_expression(template, **replacements): + """Variant of replace that generates expressions, instead of code blocks.""" + replacement = replace(template, **replacements) + if len(replacement) != 1: + raise ValueError( + 'single expression expected; for more general templates use replace') + node = replacement[0] + if not isinstance(node, gast.Expr): + raise ValueError( + 'the template is expected to generate an expression node; instead ' + 'found %s' % node) + return node.value diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a01f8bf04c4faa6ec1779e0fb306155d99f5bd09 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/templates_test.py @@ -0,0 +1,168 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for templates module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import imp + +import gast + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.python.platform import test + + +class TemplatesTest(test.TestCase): + + def test_replace_tuple(self): + template = """ + def test_fn(a, c): + return b, + """ + + node = templates.replace(template, b=('a', 'c'))[0] + result, _ = compiler.ast_to_object(node) + + self.assertEquals((2, 3), result.test_fn(2, 3)) + + def test_replace_variable(self): + template = """ + def test_fn(a): + a += 1 + a = 2 * a + 1 + return b + """ + + node = templates.replace(template, a='b')[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_function_name(self): + template = """ + def fname(a): + a += 1 + a = 2 * a + 1 + return a + """ + + node = templates.replace(template, fname='test_fn')[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_code_block(self): + template = """ + def test_fn(a): + block + return a + """ + + node = templates.replace( + template, + block=[ + gast.Assign([ + gast.Name('a', None, None) + ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))), + ] * 2)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn(1)) + + def test_replace_attribute(self): + template = """ + def test_fn(a): + return a.foo + """ + + node = templates.replace(template, foo='b')[0] + result, _ = compiler.ast_to_object(node) + mod = imp.new_module('test') + mod.b = 3 + self.assertEquals(3, result.test_fn(mod)) + + with self.assertRaises(ValueError): + templates.replace(template, foo=1) + + def test_replace_call_keyword(self): + template = """ + def test_fn(): + def f(a, d, f): + return a + d + f + return f(1, kws=None) + """ + + source = parser.parse_expression('f(d=3, f=5)') + node = templates.replace(template, kws=source.keywords)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(9, result.test_fn()) + + with self.assertRaises(ValueError): + templates.replace(template, kws=[]) + templates.replace(template, kws=1) + + def test_replace_name_with_call(self): + template = """ + def test_fn(): + b = 5 + def g(a): + return 3 * a + def f(): + return g + return foo + """ + + source = parser.parse_expression('f()(b)') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(15, result.test_fn()) + + def test_replace_name_with_dict(self): + template = """ + def test_fn(): + return foo['bar'] + """ + + source = parser.parse_expression('{\'bar\': 3}') + node = templates.replace(template, foo=source)[0] + result, _ = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn()) + + def replace_as_expression(self): + template = """ + foo(a) + """ + + node = templates.replace(template, foo='bar', a='baz') + self.assertTrue(node is gast.Call) + self.assertEqual(node.func.id, 'bar') + self.assertEqual(node.func.args[0].id, 'baz') + + def replace_as_expression_restrictions(self): + template = """ + foo(a) + bar(b) + """ + with self.assertRaises(ValueError): + templates.replace_as_expression(template) + with self.assertRaises(ValueError): + templates.replace('') + with self.assertRaises(ValueError): + templates.replace('a = b') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e102ab763011e744863e95e858638b6bea689d18 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -0,0 +1,143 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A node transformer that includes utilities for SCT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import gast +import six + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import pretty_printer + + +class AutographParseError(SyntaxError): + pass + + +def try_ast_to_source(node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + + +class Base(gast.NodeTransformer): + """Base class for specialized transformers. + + Scope-local state tracking: to keep state across nodes, at the level of + (possibly nested) scopes, use enter/exit_local_scope and set/get_local. + You must call enter/exit_local_scope manually, but the transformer detects + when they are not properly paired. + """ + + def __init__(self, context): + """Initialize the transformer. Subclasses should call this. + + Args: + context: An EntityContext. + """ + self._lineno = 0 + self._col_offset = 0 + self.context = context + self._enclosing_entities = [] + + # A stack that allows keeping mutable, scope-local state where scopes may be + # nested. For example, it can be used to track the usage of break + # statements in each loop, where loops may be nested. + self._local_scope_state = [] + self.enter_local_scope() + + @property + def enclosing_entities(self): + return tuple(self._enclosing_entities) + + def enter_local_scope(self): + self._local_scope_state.append({}) + + def exit_local_scope(self): + return self._local_scope_state.pop() + + def set_local(self, name, value): + self._local_scope_state[-1][name] = value + + def get_local(self, name, default=None): + return self._local_scope_state[-1].get(name, default) + + def debug_print(self, node): + """Helper method useful for debugging.""" + if __debug__: + print(pretty_printer.fmt(node)) + return node + + def visit_block(self, nodes): + """Helper equivalent to generic_visit, but for node lists.""" + results = [] + for node in nodes: + replacement = self.visit(node) + if replacement: + if isinstance(replacement, (list, tuple)): + results.extend(replacement) + else: + results.append(replacement) + return results + + def visit(self, node): + source_code = self.context.source_code + source_file = self.context.source_file + did_enter_function = False + local_scope_state_size = len(self._local_scope_state) + + try: + if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): + self._enclosing_entities.append(node) + did_enter_function = True + + if source_code and hasattr(node, 'lineno'): + self._lineno = node.lineno + self._col_offset = node.col_offset + if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + return node + return super(Base, self).visit(node) + + except (ValueError, AttributeError, KeyError, NotImplementedError, + AssertionError) as e: + msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( + e.__class__.__name__, str(e), try_ast_to_source(node), + pretty_printer.fmt(node, color=False)) + if source_code: + line = source_code.splitlines()[self._lineno - 1] + else: + line = '' + six.reraise(AutographParseError, + AutographParseError( + msg, + (source_file, self._lineno, self._col_offset + 1, line)), + sys.exc_info()[2]) + finally: + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_state_size != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.') diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f96b0dc377521a482d347436caa98633a0a32c8a --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -0,0 +1,179 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for templates module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.python.platform import test + + +class TransformerTest(test.TestCase): + + def _context_for_nodetesting(self): + return context.EntityContext( + namer=None, + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None, + recursive=False) + + def test_entity_scope_tracking(self): + + class TestTransformer(transformer.Base): + + # The choice of note to assign to is arbitrary. Using Assign because it's + # easy to find in the tree. + def visit_Assign(self, node): + anno.setanno(node, 'enclosing_entities', self.enclosing_entities) + return self.generic_visit(node) + + # This will show up in the lambda function. + def visit_BinOp(self, node): + anno.setanno(node, 'enclosing_entities', self.enclosing_entities) + return self.generic_visit(node) + + tr = TestTransformer(self._context_for_nodetesting()) + + def test_function(): + a = 0 + + class TestClass(object): + + def test_method(self): + b = 0 + def inner_function(x): + c = 0 + d = lambda y: (x + y) + return c, d + return b, inner_function + return a, TestClass + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + test_function_node = node.body[0] + test_class = test_function_node.body[1] + test_method = test_class.body[0] + inner_function = test_method.body[1] + lambda_node = inner_function.body[1].value + + a = test_function_node.body[0] + b = test_method.body[0] + c = inner_function.body[0] + lambda_expr = lambda_node.body + + self.assertEqual( + (test_function_node,), anno.getanno(a, 'enclosing_entities')) + self.assertEqual((test_function_node, test_class, test_method), + anno.getanno(b, 'enclosing_entities')) + self.assertEqual( + (test_function_node, test_class, test_method, inner_function), + anno.getanno(c, 'enclosing_entities')) + self.assertEqual((test_function_node, test_class, test_method, + inner_function, lambda_node), + anno.getanno(lambda_expr, 'enclosing_entities')) + + def test_statement_info_stack(self): + + class TestTransformer(transformer.Base): + + # Extract all string constants from the block. + def visit_Str(self, node): + self.set_local('string', self.get_local('string', default='') + node.s) + return self.generic_visit(node) + + def _annotate_result(self, node): + self.enter_local_scope() + node = self.generic_visit(node) + anno.setanno(node, 'test', self.get_local('string')) + self.exit_local_scope() + return node + + def visit_While(self, node): + return self._annotate_result(node) + + def visit_For(self, node): + return self._annotate_result(node) + + tr = TestTransformer(self._context_for_nodetesting()) + + def test_function(a): + """Docstring.""" + assert a == 'This should not be counted' + for i in range(3): + _ = 'a' + if i > 2: + return 'b' + else: + _ = 'c' + while True: + raise '1' + return 'nor this' + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + for_node = node.body[0].body[2] + while_node = for_node.body[1].orelse[1] + + self.assertFalse(anno.hasanno(for_node, 'string')) + self.assertEqual('abc', anno.getanno(for_node, 'test')) + self.assertFalse(anno.hasanno(while_node, 'string')) + self.assertEqual('1', anno.getanno(while_node, 'test')) + + def test_statement_info_stack_checks_integrity(self): + + class TestTransformer(transformer.Base): + + def visit_If(self, node): + self.enter_local_scope() + return self.generic_visit(node) + + def visit_For(self, node): + node = self.generic_visit(node) + self.exit_local_scope() + return node + + tr = TestTransformer(self._context_for_nodetesting()) + + def no_exit(a): + if a > 0: + print(a) + return None + + node, _ = parser.parse_entity(no_exit) + with self.assertRaises(AssertionError): + tr.visit(node) + + def no_entry(a): + for _ in a: + print(a) + + node, _ = parser.parse_entity(no_entry) + with self.assertRaises(AssertionError): + tr.visit(node) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD similarity index 88% rename from tensorflow/contrib/py2tf/utils/BUILD rename to tensorflow/contrib/autograph/utils/BUILD index c2fdd40707775783140390e4b5c0186c9c3e562e..d3a1b9468892531cbc51bc13de66ef595f1a95f8 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -20,26 +20,31 @@ py_library( name = "utils", srcs = [ "__init__.py", + "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", - "printing.py", "py_func.py", "tensor_list.py", + "testing.py", "type_check.py", + "type_hints.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", + "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], ) py_test( - name = "context_managers_test", - srcs = ["context_managers_test.py"], + name = "builtins_test", + srcs = ["builtins_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", @@ -47,8 +52,8 @@ py_test( ) py_test( - name = "misc_test", - srcs = ["misc_test.py"], + name = "context_managers_test", + srcs = ["context_managers_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -57,8 +62,8 @@ py_test( ) py_test( - name = "multiple_dispatch_test", - srcs = ["multiple_dispatch_test.py"], + name = "misc_test", + srcs = ["misc_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -67,8 +72,8 @@ py_test( ) py_test( - name = "py_func_test", - srcs = ["py_func_test.py"], + name = "multiple_dispatch_test", + srcs = ["multiple_dispatch_test.py"], srcs_version = "PY2AND3", deps = [ ":utils", @@ -77,9 +82,10 @@ py_test( ) py_test( - name = "printing_test", - srcs = ["printing_test.py"], + name = "py_func_test", + srcs = ["py_func_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":utils", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..817d4126d106487e1fea3e442712a69bbfccd7f3 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility module that contains APIs usable in the generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin +from tensorflow.contrib.autograph.utils.builtins import dynamic_print +from tensorflow.contrib.autograph.utils.builtins import dynamic_range +from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns +from tensorflow.contrib.autograph.utils.misc import alias_tensors +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is +from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not +from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond +from tensorflow.contrib.autograph.utils.py_func import wrap_py_func +from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append +from tensorflow.contrib.autograph.utils.testing import fake_tf +from tensorflow.contrib.autograph.utils.type_check import is_tensor +from tensorflow.contrib.autograph.utils.type_hints import set_element_type diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..211e8eaee9082dd3e4f035e4379871cd2e154a39 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -0,0 +1,105 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Builtin conversion utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops + + +def dynamic_builtin(f, *args, **kwargs): + """Converts a builtin function call inline.""" + if f is len: + return dynamic_len(*args, **kwargs) + if six.PY2 and f is xrange: + return dynamic_range(*args, **kwargs) + if f is range: + return dynamic_range(*args, **kwargs) + raise ValueError('%s is not supported' % f) + + +def dynamic_len(list_or_tensor): + """Implementation of len using dynamic dispatch.""" + if tensor_util.is_tensor(list_or_tensor): + shape = list_or_tensor.shape + if not shape: + raise ValueError( + 'len requires non-zero rank for tensor "%s"' % list_or_tensor) + return array_ops.shape(list_or_tensor)[0] + return len(list_or_tensor) + + +def dynamic_range(start_or_stop, stop=None, step=None): + """Implementation of range using dynamic dispatch.""" + if type_check.is_tensor(start_or_stop, stop, step): + if step is not None: + return math_ops.range(start_or_stop, stop, step) + if stop is not None: + return math_ops.range(start_or_stop, stop) + return math_ops.range(start_or_stop) + + if step is not None: + return range(start_or_stop, stop, step) + elif stop is not None: + return range(start_or_stop, stop) + return range(start_or_stop) + + +def is_tf_print_compatible(value): + # TODO(mdan): Enable once we can reliably test this. + # This is currently disabled because we can't capture the output of + # op kernels from Python. + del value + return False + + +def dynamic_print(*values): + """Implementation of print using dynamic dispatch. + + The function attempts to use tf.Print if all the values are compatible. + Otherwise, it will fall back to py_func. + + Args: + *values: values to print + Returns: + A dummy value indicating the print completed. If tf. + """ + + if all(map(is_tf_print_compatible, values)): + return logging_ops.Print(1, values) + + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) + print(*vals) + # The flush helps avoid garbled output in IPython. + sys.stdout.flush() + + return py_func.wrap_py_func( + print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py new file mode 100644 index 0000000000000000000000000000000000000000..163e6984079fea5c3b3d9aeda0ec8048d651686f --- /dev/null +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -0,0 +1,112 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class BuiltinsTest(test.TestCase): + + def test_dynamic_len_tf_scalar(self): + a = constant_op.constant(1) + + with self.assertRaises(ValueError): + with self.test_session() as sess: + sess.run(builtins.dynamic_builtin(len, a)) + + def test_dynamic_len_tf_array(self): + a = constant_op.constant([1, 2, 3]) + + with self.test_session() as sess: + self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_tf_matrix(self): + a = constant_op.constant([[1, 2], [3, 4]]) + + with self.test_session() as sess: + self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_py_list(self): + a = [3] * 5 + + self.assertEqual(5, builtins.dynamic_builtin(len, a)) + + def test_dynamic_range_all_python(self): + self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) + self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1]) + + def test_dynamic_range_tf(self): + with self.test_session() as sess: + self.assertAllEqual( + sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), + [0, 1, 2]) + self.assertAllEqual( + sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), + [1, 2]) + self.assertAllEqual( + sess.run( + builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), + [2, 1]) + + def test_dynamic_range_detection(self): + def range(x): # pylint:disable=redefined-builtin + return x + + # Functions that just have the names of builtins are rejected. + with self.assertRaises(ValueError): + self.assertEqual(builtins.dynamic_builtin(range, 1), 1) + if six.PY2: + self.assertListEqual( + list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) + self.assertListEqual( + list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + + def test_dynamic_print_tf(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_dynamic_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/context_managers.py b/tensorflow/contrib/autograph/utils/context_managers.py similarity index 85% rename from tensorflow/contrib/py2tf/utils/context_managers.py rename to tensorflow/contrib/autograph/utils/context_managers.py index 38d9e11fe9069722b9023fee848bf53e1f72de6a..3d150a95817b83c4d7aaa78dc250092dcc4c5a9b 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers.py +++ b/tensorflow/contrib/autograph/utils/context_managers.py @@ -21,6 +21,7 @@ from __future__ import print_function import contextlib from tensorflow.python.framework import ops +from tensorflow.python.ops import tensor_array_ops def control_dependency_on_returns(return_value): @@ -34,9 +35,15 @@ def control_dependency_on_returns(return_value): Returns: A context manager. """ + def control_dependency_handle(t): + if isinstance(t, tensor_array_ops.TensorArray): + return t.flow + return t + if return_value is None: return contextlib.contextmanager(lambda: (yield))() # TODO(mdan): Filter to tensor objects. if not isinstance(return_value, (list, tuple)): return_value = (return_value,) + return_value = tuple(control_dependency_handle(t) for t in return_value) return ops.control_dependencies(return_value) diff --git a/tensorflow/contrib/py2tf/utils/context_managers_test.py b/tensorflow/contrib/autograph/utils/context_managers_test.py similarity index 82% rename from tensorflow/contrib/py2tf/utils/context_managers_test.py rename to tensorflow/contrib/autograph/utils/context_managers_test.py index 633ba93540e696889a6b2b71b40b999da39d48ff..42e27724b9856f715b524cdd7539897851715638 100644 --- a/tensorflow/contrib/py2tf/utils/context_managers_test.py +++ b/tensorflow/contrib/autograph/utils/context_managers_test.py @@ -18,8 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import context_managers +from tensorflow.contrib.autograph.utils import context_managers from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -32,6 +34,9 @@ class ContextManagersTest(test.TestCase): with context_managers.control_dependency_on_returns( constant_op.constant(1)): pass + with context_managers.control_dependency_on_returns( + tensor_array_ops.TensorArray(dtypes.int32, size=1)): + pass with context_managers.control_dependency_on_returns( [constant_op.constant(1), constant_op.constant(2)]): diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/autograph/utils/misc.py similarity index 100% rename from tensorflow/contrib/py2tf/utils/misc.py rename to tensorflow/contrib/autograph/utils/misc.py diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py similarity index 78% rename from tensorflow/contrib/py2tf/utils/misc_test.py rename to tensorflow/contrib/autograph/utils/misc_test.py index bfcb304c838df69e9e3961907362c7939c065117..71e358c33e1ea9887d267c67bc80362bac26c3a6 100644 --- a/tensorflow/contrib/py2tf/utils/misc_test.py +++ b/tensorflow/contrib/autograph/utils/misc_test.py @@ -18,29 +18,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import misc -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import variables +from tensorflow.contrib.autograph.utils.misc import alias_tensors +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.ops.variables import Variable from tensorflow.python.platform import test -class ContextManagersTest(test.TestCase): +class MiscTest(test.TestCase): def test_alias_single_tensor(self): - a = constant_op.constant(1) + a = constant(1) - new_a = misc.alias_tensors(a) + new_a = alias_tensors(a) self.assertFalse(new_a is a) with self.test_session() as sess: self.assertEqual(1, sess.run(new_a)) def test_alias_tensors(self): - a = constant_op.constant(1) - v = variables.Variable(2) + a = constant(1) + v = Variable(2) s = 'a' l = [1, 2, 3] - new_a, new_v, new_s, new_l = misc.alias_tensors(a, v, s, l) + new_a, new_v, new_s, new_l = alias_tensors(a, v, s, l) self.assertFalse(new_a is a) self.assertTrue(new_v is v) diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/contrib/autograph/utils/multiple_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..70eef5676f61bcd978ea53260f0b86a817f2bd7c --- /dev/null +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch.py @@ -0,0 +1,66 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for type-dependent behavior used in autograph-generated code.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils.type_check import is_tensor +from tensorflow.python.ops import control_flow_ops + + +def dynamic_is(left, right): + # TODO(alexbw) if we're sure we should leave 'is' in place, + # then change the semantics in converters/logical_expressions.py + return left is right + + +def dynamic_is_not(left, right): + return left is not right + + +def run_cond(condition, true_fn, false_fn): + """Type-dependent functional conditional. + + Args: + condition: A Tensor or Python bool. + true_fn: A Python callable implementing the true branch of the conditional. + false_fn: A Python callable implementing the false branch of the + conditional. + + Returns: + result: The result of calling the appropriate branch. If condition is a + Tensor, tf.cond will be used. Otherwise, a standard Python if statement will + be ran. + """ + if is_tensor(condition): + return control_flow_ops.cond(condition, true_fn, false_fn) + else: + return py_cond(condition, true_fn, false_fn) + + +def py_cond(condition, true_fn, false_fn): + """Functional version of Python's conditional.""" + if condition: + results = true_fn() + else: + results = false_fn() + + # The contract for the branch functions is to return tuples, but they should + # be collapsed to a single element when there is only one output. + if len(results) == 1: + return results[0] + return results diff --git a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py similarity index 50% rename from tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py rename to tensorflow/contrib/autograph/utils/multiple_dispatch_test.py index 5bb4d4086b002211eebb86783bb7212c707a1418..f72f8e94a0df815f7d517e2b81ffc86c5c545f07 100644 --- a/tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import multiple_dispatch + +import numpy as np + +from tensorflow.contrib.autograph.utils import multiple_dispatch from tensorflow.python.client.session import Session from tensorflow.python.framework.constant_op import constant from tensorflow.python.platform import test @@ -25,44 +28,47 @@ from tensorflow.python.platform import test class MultipleDispatchTest(test.TestCase): + def test_dynamic_is_python(self): + a = np.eye(3) + also_a = a + not_actually_a = np.eye(3) + should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) + should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) + should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) + should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) + + def test_dynamic_is_tf(self): + with Session().as_default(): + a = constant([2.0]) + also_a = a + not_actually_a = constant([2.0]) + should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) + should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) + should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) + should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) + self.assertTrue(should_be_true1) + self.assertTrue(should_be_true2) + self.assertFalse(should_be_false1) + self.assertFalse(should_be_false2) + def test_run_cond_python(self): - true_fn = lambda: 2.0 - false_fn = lambda: 3.0 - self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) - self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + true_fn = lambda: (2,) + false_fn = lambda: (3,) + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3) def test_run_cond_tf(self): - - true_fn = lambda: constant([2.0]) - false_fn = lambda: constant([3.0]) + true_fn = lambda: (constant(2),) + false_fn = lambda: (constant(3),) with Session() as sess: out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) - self.assertEqual(sess.run(out), 2.0) + self.assertEqual(sess.run(out), 2) out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) - self.assertEqual(sess.run(out), 3.0) - - def test_run_while_python(self): - cond_fn = lambda x, t, s: x > t - body_fn = lambda x, t, s: (x * s, t, s) - - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5]) - self.assertEqual(x, 0.75) - - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5]) - self.assertEqual(x, 3.0) - - def test_run_while_tf(self): - cond_fn = lambda x, t, s: x > t - body_fn = lambda x, t, s: (x * s, t, s) - - with Session() as sess: - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, - [constant(3.0), 1.0, 0.5]) - self.assertEqual(sess.run(x), 0.75) - - x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, - [constant(3.0), 4.0, 0.5]) - self.assertEqual(sess.run(x), 3.0) + self.assertEqual(sess.run(out), 3) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/contrib/autograph/utils/py_func.py new file mode 100644 index 0000000000000000000000000000000000000000..11ebfb2e49f0e762b56ae2cde2b76d2e24032d72 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pyfunc creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import script_ops + + +class MatchDType(namedtuple('MatchDType', ('arg_number',))): + """Allows matching the dtype of an argument. + + Used in conjunction with function calls. For example, MatchDType(0) will + match the DType of the first argument. + """ + + pass + + +def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False): + """Helper that wraps a callable to py_func. + + The helper passes tensor arguments through the py_func interface. Non-tensor + arguments are allowed, and will be passed to f directly. Note that non-tensor + arguments are captured by f will not update every time the wrapper is + called (this is consistent with its argument list, which only includes + the tensor arguments). In general, it's safest not to reuse this wrapper. + + Args: + f: Callable + return_dtypes: None, individual of tuple/list of DType or MatchDType, the + data type for each of f's return value(s). Set to None if f has no + return values or use_dummy_return is True. Use MatchDType to define a + dtype identical to that of `i`th argument (argument 0 is the first); + an argument must of Tensor type if it is to be used with MatchDType. + args: Positional arguments for f, as list or tuple. + kwargs: Keyword arguments for f, as dict with string keys. May be None. + use_dummy_return: If True, the function will return a dummy value of 1 + and discard its actual return value. + Returns: + The return values of f converted to tensor. + Raises: + ValueError: if any of the arguments are incorrect. + """ + + if return_dtypes and use_dummy_return: + raise ValueError('if use_dummy_return is True, return_dtypes must be empty') + + tensor_args = [] + tensor_args_idx = {} + + # Of the positional arguments, only grab the tensor ones to be passed through + # the py_func. + n_args = len(args) + arg_is_tensor = tuple(map(tensor_util.is_tensor, args)) + for i in range(n_args): + if arg_is_tensor[i]: + tensor_args_idx[i] = len(tensor_args) + tensor_args.append(args[i]) + + # We essentially take the tensor kwargs, if any, and add them to the list of + # positional arguments. The kwargs are then reconstructed inside the py_func. + # + # For example, if + # + # args = [Tensor(1), 'foo'] + # kwargs = {'a': Tensor(2), 'b': 'bar'} + # + # Then + # + # tensor_args = (Tensor(1), Tensor(2)) + # kwarg_keys = ('a', 'b') + if kwargs: + kwarg_keys = tuple(kwargs.keys()) + kwarg_is_tensor = {k: tensor_util.is_tensor(kwargs[k]) for k in kwarg_keys} + for k in kwarg_keys: + if kwarg_is_tensor[k]: + tensor_args_idx[k] = len(tensor_args) + tensor_args.append(kwargs[k]) + else: + kwarg_keys = () + + # Set up return dtypes. + def match_arg_dtype(arg_number): + arg = args[arg_number] + if not arg_is_tensor[arg_number]: + raise ValueError( + 'argument %d was used with MatchDType and must be a tf.Tensor, but ' + 'was %s instead' % (arg_number, type(arg))) + return arg.dtype + + if return_dtypes: + if isinstance(return_dtypes, MatchDType): + return_dtypes = match_arg_dtype(return_dtypes.arg_number) + elif isinstance(return_dtypes, (list, tuple)): + return_dtypes = tuple( + match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a + for a in return_dtypes) + else: + assert isinstance(return_dtypes, dtypes.DType) + + def f_wrapper(*tensor_args): + f_args = tuple(tensor_args[tensor_args_idx[i]] if arg_is_tensor[i] else a + for i, a in enumerate(args)) + f_kwargs = { + k: tensor_args[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k] + for i, k in enumerate(kwarg_keys) + } + retval = f(*f_args, **f_kwargs) + return 1 if use_dummy_return else retval + + return script_ops.py_func(f_wrapper, tensor_args, dtypes.int64 + if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2468263142f14332e86db99d198ba0f5c633dc69 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/py_func_test.py @@ -0,0 +1,103 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrap_py_func module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class PyFuncTest(test.TestCase): + + def test_wrap_py_func_simple(self): + + def test_fn(a, b, c): + return a + b + c + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (1, constant_op.constant(1), 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1)) + self.assertEqual(3, sess.run(result)) + result = py_func.wrap_py_func( + test_fn, dtypes.int64, + (constant_op.constant(1), 1, constant_op.constant(1))) + self.assertEqual(3, sess.run(result)) + + def test_wrap_py_func_complex_args(self): + + class TestClass(object): + + def __init__(self): + self.foo = 5 + + def test_fn(a, b): + return a * b.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass())) + self.assertEqual(35, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass())) + self.assertEqual(35, sess.run(result)) + + def test_wrap_py_func_kwargs(self): + + class TestClass(object): + + def __init__(self, foo): + self.foo = foo + + def test_fn(a, b, c, d): + return a * b.foo + c * d.foo + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), { + 'c': 11, + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + result = py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass(5)), { + 'c': constant_op.constant(11), + 'd': TestClass(13) + }) + self.assertEqual(178, sess.run(result)) + + def test_wrap_py_func_dummy_return(self): + + side_counter = [0] + + def test_fn(_): + side_counter[0] += 1 + + with self.test_session() as sess: + result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([1], side_counter) + result = py_func.wrap_py_func( + test_fn, None, (constant_op.constant(5),), use_dummy_return=True) + self.assertEqual(1, sess.run(result)) + self.assertEqual([2], side_counter) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/tensor_list.py b/tensorflow/contrib/autograph/utils/tensor_list.py similarity index 67% rename from tensorflow/contrib/py2tf/utils/tensor_list.py rename to tensorflow/contrib/autograph/utils/tensor_list.py index b6ff49e2a0eff384f10903e12212ab929e267804..2556f412891b4f0b954af5a6f0193341a6a5020a 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list.py +++ b/tensorflow/contrib/autograph/utils/tensor_list.py @@ -18,7 +18,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +def dynamic_list_append(target, element): + """Converts a list append call inline.""" + if isinstance(target, tensor_array_ops.TensorArray): + return target.write(target.size(), element) + # TODO(mdan): What's the right way to check this? + # TODO(mdan): We may not need this branch. + # It may be possible to use TensorList alone if the loop body will not + # require wrapping it, although we'd have to think about an autoboxing + # mechanism for lists received as parameter. + if isinstance(target, ops.Tensor): + return list_ops.tensor_list_push_back(target, element) + + # Python targets (including TensorList): fallback to their original append. + target.append(element) + return target class TensorList(object): diff --git a/tensorflow/contrib/py2tf/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py similarity index 71% rename from tensorflow/contrib/py2tf/utils/tensor_list_test.py rename to tensorflow/contrib/autograph/utils/tensor_list_test.py index b5e554a162674e08da21785dcbe193c54647f128..d58489eb68b6b949a4276520605c62b7c2825558 100644 --- a/tensorflow/contrib/py2tf/utils/tensor_list_test.py +++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py @@ -12,22 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for PyFlow list.""" +"""Tests for Autograph lists.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import tensor_list as tl +from tensorflow.contrib.autograph.utils import tensor_list as tl from tensorflow.python.client.session import Session from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.constant_op import constant +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test class TensorListTest(test.TestCase): + def _shape(self, shape_tuple): + return constant(shape_tuple, dtypes.int32) + + def test_dynamic_list_append(self): + l = [] + l = tl.dynamic_list_append(l, 1) + self.assertListEqual(l, [1]) + + l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32) + l = tl.dynamic_list_append(l, 1) + s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(s), [1]) + + l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) + l = tl.dynamic_list_append(l, 1) + s = l.stack() + with self.test_session() as sess: + self.assertAllEqual(sess.run(s), [1]) + + l = tl.TensorList(self._shape(()), dtypes.int32) + l = tl.dynamic_list_append(l, 1) + with self.test_session() as sess: + self.assertAllEqual(sess.run(l[0]), 1) + def test_list_append_python(self): with context.eager_mode(): a = constant(3.0) diff --git a/tensorflow/contrib/py2tf/pyct/parser.py b/tensorflow/contrib/autograph/utils/testing.py similarity index 65% rename from tensorflow/contrib/py2tf/pyct/parser.py rename to tensorflow/contrib/autograph/utils/testing.py index dc7df883b349becd860bb0dbceab22cb39c750b5..cb4785d0dc0f4674b3560418daeb6733364b21e7 100644 --- a/tensorflow/contrib/py2tf/pyct/parser.py +++ b/tensorflow/contrib/autograph/utils/testing.py @@ -12,29 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Converting code to AST. - -Adapted from Tangent. -""" +"""Testing utilities.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import textwrap - -import gast - -from tensorflow.python.util import tf_inspect - +import imp -def parse_entity(entity): - """Return the AST of given entity.""" - source = tf_inspect.getsource(entity) - source = textwrap.dedent(source) - return parse_str(source), source +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops -def parse_str(src): - """Return the AST of given piece of code.""" - return gast.parse(src) +def fake_tf(): + """Creates a fake module that looks like TensorFlow, for testing.""" + mod = imp.new_module('tensorflow') + mod_contents = dict() + mod_contents.update(math_ops.__dict__) + mod_contents.update(ops.__dict__) + mod_contents.update(mod.__dict__) + mod.__dict__.update(mod_contents) + return mod diff --git a/tensorflow/contrib/py2tf/utils/type_check.py b/tensorflow/contrib/autograph/utils/type_check.py similarity index 86% rename from tensorflow/contrib/py2tf/utils/type_check.py rename to tensorflow/contrib/autograph/utils/type_check.py index 9ca2dec872c8a9ca7bedaa8603f70e3214a3e24a..8748abc47bcfb55b4d0b11178a46816249732da9 100644 --- a/tensorflow/contrib/py2tf/utils/type_check.py +++ b/tensorflow/contrib/autograph/utils/type_check.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities used in py2tf-generated code.""" +"""Utilities used in autograph-generated code.""" from __future__ import absolute_import from __future__ import division @@ -22,12 +22,12 @@ from tensorflow.python.framework import tensor_util def is_tensor(*args): - """Check if all arguments are tensors. + """Check if any arguments are tensors. Args: *args: Python objects that may or may not be tensors. Returns: - True if all *args are TensorFlow types, False if one or more are not. + True if any *args are TensorFlow types, False if none are. """ return any([tensor_util.is_tensor(a) for a in args]) diff --git a/tensorflow/contrib/py2tf/utils/type_check_test.py b/tensorflow/contrib/autograph/utils/type_check_test.py similarity index 96% rename from tensorflow/contrib/py2tf/utils/type_check_test.py rename to tensorflow/contrib/autograph/utils/type_check_test.py index 7d0428e9cccecdc67511e236bc00655a055aea29..3b67b7194c5656b193d47860f93986a985cb1aef 100644 --- a/tensorflow/contrib/py2tf/utils/type_check_test.py +++ b/tensorflow/contrib/autograph/utils/type_check_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy -from tensorflow.contrib.py2tf.utils import type_check +from tensorflow.contrib.autograph.utils import type_check from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/contrib/autograph/utils/type_hints.py similarity index 54% rename from tensorflow/contrib/bayesflow/python/ops/optimizers.py rename to tensorflow/contrib/autograph/utils/type_hints.py index fb70628d1083836281e9327e83e109493276c64f..aeb9e545610460afbe364dfcfc7a54b9aede29fe 100644 --- a/tensorflow/contrib/bayesflow/python/ops/optimizers.py +++ b/tensorflow/contrib/autograph/utils/type_hints.py @@ -12,25 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Probabilistic optimizer modules. +"""No-op utilities that provide static type hints. -See ${python/contrib.bayesflow.optimizers}. +These are used when the data type is not known at creation, for instance in the +case of empty lists. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * -from tensorflow.contrib.bayesflow.python.ops.variational_sgd_optimizer import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ - 'SGLDOptimizer', - 'VariationalSGDOptimizer', -] +def set_element_type(entity, dtype, shape=None): + """Indicates that the entity is expected hold items of specified type. -remove_undocumented(__name__, _allowed_symbols) + This function is a no-op. Its presence merely marks the data type of its + argument. The staged TensorFlow ops will reflect and assert this data type. + + Args: + entity: A Tensor or TensorArray. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + Returns: + The value of entity, unchanged. + """ + del dtype + del shape + return entity diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index ee67909133fc26ba98355db05a4b90d3dfa6b97b..d65c990c87cbc316472237d183c03765416501e7 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -112,14 +112,3 @@ py_test( "//tensorflow/python:script_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD index 6db627faad1df4a4b73082e74e7754829ff2b514..7cb2d8079bd18660f72eab92654629434ce4d6a5 100644 --- a/tensorflow/contrib/batching/test_util/BUILD +++ b/tensorflow/contrib/batching/test_util/BUILD @@ -8,17 +8,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - cc_library( name = "fake_clock_env", testonly = 1, diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD index 2a84a7712a8fa66e89db41ff4e7ebe4f620029ca..8f81b6702f2807d7da7e72190ce2d86b28e52113 100644 --- a/tensorflow/contrib/batching/util/BUILD +++ b/tensorflow/contrib/batching/util/BUILD @@ -8,18 +8,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "**/google_*", - ], - ), -) - cc_library( name = "periodic_function_dynamic", hdrs = ["periodic_function.h"], diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 74712aeb67c3f0a31def78f25a0298f9c02c9590..5a2d7f6a3c0ba233299a5790fa80488786712f3c 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -37,126 +37,6 @@ py_library( ], ) -cuda_py_test( - name = "metropolis_hastings_test", - size = "medium", - srcs = ["python/kernel_tests/metropolis_hastings_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "csiszar_divergence_test", - size = "medium", - srcs = ["python/kernel_tests/csiszar_divergence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], - tags = [ - "manual", # b/64490288 - "notap", - ], -) - -cuda_py_test( - name = "custom_grad_test", - size = "small", - srcs = ["python/kernel_tests/custom_grad_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "layers_conv_variational_test", - size = "small", - srcs = ["python/kernel_tests/layers_conv_variational_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], -) - -cuda_py_test( - name = "layers_dense_variational_test", - size = "small", - srcs = ["python/kernel_tests/layers_dense_variational_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], -) - -cuda_py_test( - name = "mcmc_diagnostics_test", - size = "small", - srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:spectral_ops_test_util", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], -) - cuda_py_test( name = "monte_carlo_test", size = "small", @@ -177,117 +57,3 @@ cuda_py_test( "//tensorflow/python:random_seed", ], ) - -cuda_py_test( - name = "halton_sequence_test", - size = "small", - srcs = ["python/kernel_tests/halton_sequence_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], - tags = ["no_mac"], # b/73192243 -) - -cuda_py_test( - name = "hmc_test", - size = "medium", - srcs = ["python/kernel_tests/hmc_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], -) - -cuda_py_test( - name = "sgld_optimizer_test", - size = "small", - srcs = ["python/kernel_tests/sgld_optimizer_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["notsan"], -) - -cuda_py_test( - name = "variable_utils_test", - size = "small", - srcs = ["python/kernel_tests/variable_utils_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "variational_sgd_optimizer_test", - size = "small", - srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", - ], - tags = ["notsan"], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/bayesflow/README.md b/tensorflow/contrib/bayesflow/README.md new file mode 100644 index 0000000000000000000000000000000000000000..10323dc6d59918a9f8cf1840d06dcd219dfe3568 --- /dev/null +++ b/tensorflow/contrib/bayesflow/README.md @@ -0,0 +1,17 @@ +# Notice + +`tf.contrib.bayesflow` has moved! + +See new code at [github.com/tensorflow/probability]( +https://github.com/tensorflow/probability). + +Switch imports with: + +```python +# old +import tensorflow as tf +tfp = tf.contrib.bayesflow + +# new +import tensorflow_probability as tfp +``` diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 528c4fbacd06c7b0defa0e32bd24a98b2bc07b64..41a8c920fc4e81af90f4c94a149d8c404c58b747 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -21,36 +21,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence -from tensorflow.contrib.bayesflow.python.ops import custom_grad -from tensorflow.contrib.bayesflow.python.ops import halton_sequence -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops import layers -from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo -from tensorflow.contrib.bayesflow.python.ops import optimizers -from tensorflow.contrib.bayesflow.python.ops import variable_utils # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'csiszar_divergence', - 'custom_grad', - 'entropy', - 'halton_sequence', - 'hmc', - 'layers', - 'metropolis_hastings', - 'mcmc_diagnostics', 'monte_carlo', - 'optimizers', - 'special_math', - 'stochastic_variables', - 'variable_utils', - 'variational_inference', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py deleted file mode 100644 index 2e94b7206de4f7c40c89f083f3bfa2a22bb7b917..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py +++ /dev/null @@ -1,1004 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Csiszar Divergence Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl -from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -cd = csiszar_divergence_impl - - -def tridiag(d, diag_value, offdiag_value): - """d x d matrix with given value on diag, and one super/sub diag.""" - diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value) - three_bands = array_ops.matrix_band_part( - array_ops.fill([d, d], offdiag_value), 1, 1) - return diag_mat + three_bands - - -class AmariAlphaTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for alpha in [-1., 0., 1., 2.]: - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(0., alpha=alpha, - self_normalized=normalized).eval(), - 0.) - - def test_correct_when_alpha0(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0.).eval(), - -self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=0., self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - def test_correct_when_alpha1(self): - with self.test_session(): - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1.).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.amari_alpha(self._logu, alpha=1., self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - def test_correct_when_alpha_not_01(self): - for alpha in [-2, -1., -0.5, 0.5, 2.]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=False).eval(), - ((self._u**alpha - 1)) / (alpha * (alpha - 1.))) - - self.assertAllClose( - cd.amari_alpha(self._logu, - alpha=alpha, - self_normalized=True).eval(), - ((self._u**alpha - 1.) - - alpha * (self._u - 1)) / (alpha * (alpha - 1.))) - - -class KLReverseTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_reverse(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_reverse(self._logu).eval(), - -self._logu) - - self.assertAllClose( - cd.kl_reverse(self._logu, self_normalized=True).eval(), - -self._logu + (self._u - 1.)) - - -class KLForwardTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - for normalized in [True, False]: - with self.test_session(graph=ops.Graph()): - self.assertAllClose( - cd.kl_forward(0., self_normalized=normalized).eval(), - 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.kl_forward(self._logu).eval(), - self._u * self._logu) - - self.assertAllClose( - cd.kl_forward(self._logu, self_normalized=True).eval(), - self._u * self._logu - (self._u - 1.)) - - -class JensenShannonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jensen_shannon(0.).eval(), np.log(0.25)) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jensen_shannon).eval()) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - cd.symmetrized_csiszar_function( - self._logu, - lambda x: cd.jensen_shannon(x, self_normalized=True)).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jensen_shannon(self._logu).eval(), - (self._u * self._logu - - (1 + self._u) * np.log1p(self._u))) - - self.assertAllClose( - cd.jensen_shannon(self._logu, self_normalized=True).eval(), - (self._u * self._logu - - (1 + self._u) * np.log((1 + self._u) / 2))) - - -class ArithmeticGeometricMeanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.arithmetic_geometric(0.).eval(), np.log(4)) - self.assertAllClose( - cd.arithmetic_geometric(0., self_normalized=True).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.arithmetic_geometric).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.arithmetic_geometric(self._logu).eval(), - (1. + self._u) * np.log((1. + self._u) / np.sqrt(self._u))) - - self.assertAllClose( - cd.arithmetic_geometric(self._logu, self_normalized=True).eval(), - (1. + self._u) * np.log(0.5 * (1. + self._u) / np.sqrt(self._u))) - - -class TotalVariationTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.total_variation(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.total_variation(self._logu).eval(), - 0.5 * np.abs(self._u - 1)) - - -class PearsonTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.pearson(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.pearson(self._logu).eval(), - np.square(self._u - 1)) - - -class SquaredHellingerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.squared_hellinger(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.squared_hellinger).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.squared_hellinger(self._logu).eval(), - np.square(np.sqrt(self._u) - 1)) - - -class TriangularTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.triangular(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.triangular).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.triangular(self._logu).eval(), - np.square(self._u - 1) / (1 + self._u)) - - -class TPowerTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.t_power(0., t=-0.1).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=0.5).eval(), 0.) - self.assertAllClose(cd.t_power(0., t=1.1).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=-0.1, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=0.5, self_normalized=True).eval(), 0.) - self.assertAllClose( - cd.t_power(0., t=1.1, self_normalized=True).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1)).eval(), - self._u ** -0.1 - 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5)).eval(), - -self._u ** 0.5 + 1.) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1)).eval(), - self._u ** 1.1 - 1.) - - def test_correct_self_normalized(self): - with self.test_session(): - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(-0.1), - self_normalized=True).eval(), - self._u ** -0.1 - 1. + 0.1 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(0.5), - self_normalized=True).eval(), - -self._u ** 0.5 + 1. + 0.5 * (self._u - 1.)) - self.assertAllClose( - cd.t_power(self._logu, t=np.float64(1.1), - self_normalized=True).eval(), - self._u ** 1.1 - 1. - 1.1 * (self._u - 1.)) - - -class Log1pAbsTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.log1p_abs(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.log1p_abs(self._logu).eval(), - self._u**(np.sign(self._u - 1)) - 1) - - -class JeffreysTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.jeffreys(0.).eval(), 0.) - - def test_symmetric(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - cd.symmetrized_csiszar_function( - self._logu, cd.jeffreys).eval()) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.jeffreys(self._logu).eval(), - 0.5 * (self._u * self._logu - self._logu)) - - -class ChiSquareTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose(cd.chi_square(0.).eval(), 0.) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.chi_square(self._logu).eval(), - self._u**2 - 1) - - -class ModifiedGanTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10, 100) - self._u = np.exp(self._logu) - - def test_at_zero(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(0.).eval(), np.log(2)) - self.assertAllClose( - cd.modified_gan(0., self_normalized=True).eval(), np.log(2)) - - def test_correct(self): - with self.test_session(): - self.assertAllClose( - cd.modified_gan(self._logu).eval(), - np.log1p(self._u) - self._logu) - - self.assertAllClose( - cd.modified_gan(self._logu, self_normalized=True).eval(), - np.log1p(self._u) - self._logu + 0.5 * (self._u - 1)) - - -class SymmetrizedCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_jensen_shannon(self): - with self.test_session(): - - # The following functions come from the claim made in the - # symmetrized_csiszar_function docstring. - def js1(logu): - return (-logu - - (1. + math_ops.exp(logu)) * ( - nn_ops.softplus(logu))) - - def js2(logu): - return 2. * (math_ops.exp(logu) * ( - logu - nn_ops.softplus(logu))) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js1).eval(), - cd.jensen_shannon(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, js2).eval(), - cd.jensen_shannon(self._logu).eval()) - - def test_jeffreys(self): - with self.test_session(): - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.jeffreys(self._logu).eval()) - - self.assertAllClose( - cd.symmetrized_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.jeffreys(self._logu).eval()) - - -class DualCsiszarFunctionTest(test.TestCase): - - def setUp(self): - self._logu = np.linspace(-10., 10., 100) - self._u = np.exp(self._logu) - - def test_kl_forward(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_forward).eval(), - cd.kl_reverse(self._logu).eval()) - - def test_kl_reverse(self): - with self.test_session(): - self.assertAllClose( - cd.dual_csiszar_function(self._logu, cd.kl_reverse).eval(), - cd.kl_forward(self._logu).eval()) - - -class MonteCarloCsiszarFDivergenceTest(test.TestCase): - - def test_kl_forward(self): - with self.test_session() as sess: - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse(self): - with self.test_session() as sess: - - q = normal_lib.Normal( - loc=np.ones(6), - scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) - - p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.07, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.02, atol=0.) - - def test_kl_reverse_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - def test_kl_forward_multidim(self): - - with self.test_session() as sess: - d = 5 # Dimension - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.]*d) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_forward, - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_forward(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=int(1e5), - seed=1) - - exact_kl = kullback_leibler.kl_divergence(p, q) - - [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ - approx_kl, approx_kl_self_normalized, exact_kl]) - - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.06, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.05, atol=0.) - - def test_score_trick(self): - - with self.test_session() as sess: - d = 5 # Dimension - num_draws = int(1e5) - seed = 1 - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(d, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [d])) - - approx_kl = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - seed=seed) - - approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence( - f=cd.kl_reverse, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed) - - approx_kl_self_normalized_score_trick = ( - cd.monte_carlo_csiszar_f_divergence( - f=lambda logu: cd.kl_reverse(logu, self_normalized=True), - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - use_reparametrization=False, - seed=seed)) - - exact_kl = kullback_leibler.kl_divergence(q, p) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - [ - approx_kl_grad_, - approx_kl_self_normalized_grad_, - approx_kl_score_trick_grad_, - approx_kl_self_normalized_score_trick_grad_, - exact_kl_grad_, - approx_kl_, - approx_kl_self_normalized_, - approx_kl_score_trick_, - approx_kl_self_normalized_score_trick_, - exact_kl_, - ] = sess.run([ - grad_sum(approx_kl), - grad_sum(approx_kl_self_normalized), - grad_sum(approx_kl_score_trick), - grad_sum(approx_kl_self_normalized_score_trick), - grad_sum(exact_kl), - approx_kl, - approx_kl_self_normalized, - approx_kl_score_trick, - approx_kl_self_normalized_score_trick, - exact_kl, - ]) - - # Test average divergence. - self.assertAllClose(approx_kl_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_, exact_kl_, - rtol=0.08, atol=0.) - - self.assertAllClose(approx_kl_score_trick_, exact_kl_, - rtol=0.02, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_score_trick_, exact_kl_, - rtol=0.08, atol=0.) - - # Test average gradient-divergence. - self.assertAllClose(approx_kl_grad_, exact_kl_grad_, - rtol=0.007, atol=0.) - - self.assertAllClose(approx_kl_self_normalized_grad_, exact_kl_grad_, - rtol=0.011, atol=0.) - - self.assertAllClose(approx_kl_score_trick_grad_, exact_kl_grad_, - rtol=0.018, atol=0.) - - self.assertAllClose( - approx_kl_self_normalized_score_trick_grad_, exact_kl_grad_, - rtol=0.017, atol=0.) - - -class CsiszarVIMCOTest(test.TestCase): - - def _csiszar_vimco_helper(self, logu): - """Numpy implementation of `csiszar_vimco_helper`.""" - - # Since this is a naive/intuitive implementation, we compensate by using the - # highest precision we can. - logu = np.float128(logu) - n = logu.shape[0] - u = np.exp(logu) - loogeoavg_u = [] # Leave-one-out geometric-average of exp(logu). - for j in range(n): - loogeoavg_u.append(np.exp(np.mean( - [logu[i, ...] for i in range(n) if i != j], - axis=0))) - loogeoavg_u = np.array(loogeoavg_u) - - loosum_u = [] # Leave-one-out sum of exp(logu). - for j in range(n): - loosum_u.append(np.sum( - [u[i, ...] for i in range(n) if i != j], - axis=0)) - loosum_u = np.array(loosum_u) - - # Natural log of the average u except each is swapped-out for its - # leave-`i`-th-out Geometric average. - log_sooavg_u = np.log(loosum_u + loogeoavg_u) - np.log(n) - - log_avg_u = np.log(np.mean(u, axis=0)) - return log_avg_u, log_sooavg_u - - def _csiszar_vimco_helper_grad(self, logu, delta): - """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`.""" - - # This code actually estimates the sum of the Jacobiab because that's what - # TF's `gradients` does. - np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper( - logu[..., None] + np.diag([delta]*len(logu))) - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper( - logu[..., None]) - return [ - (np_log_avg_u1 - np_log_avg_u) / delta, - np.sum(np_log_sooavg_u1 - np_log_sooavg_u, axis=0) / delta, - ] - - def test_vimco_helper_1(self): - """Tests that function calculation correctly handles batches.""" - - logu = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-8, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-8, atol=0.) - - def test_vimco_helper_2(self): - """Tests that function calculation correctly handles overflow.""" - - # Using 700 (rather than 1e3) since naive numpy version can't handle higher. - logu = np.float32([0., 700, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-6, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-5, atol=0.) - - def test_vimco_helper_3(self): - """Tests that function calculation correctly handles underlow.""" - - logu = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu) - [log_avg_u, log_sooavg_u] = sess.run(cd.csiszar_vimco_helper(logu)) - self.assertAllClose(np_log_avg_u, log_avg_u, - rtol=1e-5, atol=0.) - self.assertAllClose(np_log_sooavg_u, log_sooavg_u, - rtol=1e-4, atol=1e-15) - - def test_vimco_helper_gradient_using_finite_difference_1(self): - """Tests that gradient calculation correctly handles batches.""" - - logu_ = np.linspace(-100., 100., 100).reshape([10, 2, 5]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - # We skip checking against finite-difference approximation since it - # doesn't support batches. - - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_2(self): - """Tests that gradient calculation correctly handles overflow.""" - - delta = 1e-3 - logu_ = np.float32([0., 1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_helper_gradient_using_finite_difference_3(self): - """Tests that gradient calculation correctly handles underlow.""" - - delta = 1e-3 - logu_ = np.float32([0., -1000, -1, 1]) - with self.test_session() as sess: - logu = array_ops.constant(logu_) - - [ - np_grad_log_avg_u, - np_grad_log_sooavg_u, - ] = self._csiszar_vimco_helper_grad(logu_, delta) - - grad = lambda flogu: gradients_impl.gradients(flogu, logu)[0] - log_avg_u, log_sooavg_u = cd.csiszar_vimco_helper(logu) - - [ - grad_log_avg_u, - grad_log_sooavg_u, - ] = sess.run([grad(log_avg_u), grad(log_sooavg_u)]) - - self.assertAllClose(np_grad_log_avg_u, grad_log_avg_u, - rtol=delta, atol=0.) - self.assertAllClose(np_grad_log_sooavg_u, grad_log_sooavg_u, - rtol=delta, atol=0.) - # Verify claim in docstring. - self.assertAllClose( - np.ones_like(grad_log_avg_u.sum(axis=0)), - grad_log_avg_u.sum(axis=0)) - self.assertAllClose( - np.ones_like(grad_log_sooavg_u.mean(axis=0)), - grad_log_sooavg_u.mean(axis=0)) - - def test_vimco_and_gradient(self): - - with self.test_session() as sess: - dims = 5 # Dimension - num_draws = int(20) - num_batch_draws = int(3) - seed = 1 - - f = lambda logu: cd.kl_reverse(logu, self_normalized=False) - np_f = lambda logu: -logu - - p = mvn_full_lib.MultivariateNormalFullCovariance( - covariance_matrix=tridiag(dims, diag_value=1, offdiag_value=0.5)) - - # Variance is very high when approximating Forward KL, so we make - # scale_diag larger than in test_kl_reverse_multidim. This ensures q - # "covers" p and thus Var_q[p/q] is smaller. - s = array_ops.constant(1.) - q = mvn_diag_lib.MultivariateNormalDiag( - scale_diag=array_ops.tile([s], [dims])) - - vimco = cd.csiszar_vimco( - f=f, - p_log_prob=p.log_prob, - q=q, - num_draws=num_draws, - num_batch_draws=num_batch_draws, - seed=seed) - - x = q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed) - x = array_ops.stop_gradient(x) - logu = p.log_prob(x) - q.log_prob(x) - f_log_sum_u = f(cd.csiszar_vimco_helper(logu)[0]) - - grad_sum = lambda fs: gradients_impl.gradients(fs, s)[0] - - def jacobian(x): - # Warning: this function is slow and may not even finish if prod(shape) - # is larger than, say, 100. - shape = x.shape.as_list() - assert all(s is not None for s in shape) - x = array_ops.reshape(x, shape=[-1]) - r = [grad_sum(x[i]) for i in range(np.prod(shape))] - return array_ops.reshape(array_ops.stack(r), shape=shape) - - [ - logu_, - jacobian_logqx_, - vimco_, - grad_vimco_, - f_log_sum_u_, - grad_mean_f_log_sum_u_, - ] = sess.run([ - logu, - jacobian(q.log_prob(x)), - vimco, - grad_sum(vimco), - f_log_sum_u, - grad_sum(f_log_sum_u) / num_batch_draws, - ]) - - np_log_avg_u, np_log_sooavg_u = self._csiszar_vimco_helper(logu_) - - # Test VIMCO loss is correct. - self.assertAllClose(np_f(np_log_avg_u).mean(axis=0), vimco_, - rtol=1e-5, atol=0.) - - # Test gradient of VIMCO loss is correct. - # - # To make this computation we'll inject two gradients from TF: - # - grad[mean(f(log(sum(p(x)/q(x)))))] - # - jacobian[log(q(x))]. - # - # We now justify why using these (and only these) TF values for - # ground-truth does not undermine the completeness of this test. - # - # Regarding `grad_mean_f_log_sum_u_`, note that we validate the - # correctness of the zero-th order derivative (for each batch member). - # Since `cd.csiszar_vimco_helper` itself does not manipulate any gradient - # information, we can safely rely on TF. - self.assertAllClose(np_f(np_log_avg_u), f_log_sum_u_, rtol=1e-4, atol=0.) - # - # Regarding `jacobian_logqx_`, note that testing the gradient of - # `q.log_prob` is outside the scope of this unit-test thus we may safely - # use TF to find it. - - # The `mean` is across batches and the `sum` is across iid samples. - np_grad_vimco = ( - grad_mean_f_log_sum_u_ - + np.mean( - np.sum( - jacobian_logqx_ * (np_f(np_log_avg_u) - - np_f(np_log_sooavg_u)), - axis=0), - axis=0)) - - self.assertAllClose(np_grad_vimco, grad_vimco_, - rtol=1e-5, atol=0.) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py deleted file mode 100644 index a95df31ac1fd9f5038abe779391ccba5f7fe408d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/custom_grad_test.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Custom Gradient Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import custom_grad_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -cg = custom_grad_impl - - -class CustomGradientTest(test.TestCase): - - def test_works_correctly(self): - with self.test_session() as sess: - f = lambda x: x**2 / 2 - g = lambda x: (x - 1)**3 / 3 - x_ = np.linspace(-100, 100, int(1e4)) + [0.] - - x = constant_op.constant(x_) - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, x)[0] - [fx_, gx_] = sess.run([fx, gx]) - - self.assertAllClose(f(x_), fx_) - self.assertAllClose(g(x_), gx_) - - def test_works_correctly_both_f_g_zero(self): - with self.test_session() as sess: - f = lambda x: x**2 / 2 - g = lambda x: x**3 / 3 - x_ = np.linspace(-100, 100, int(1e4)) + [0.] - - x = constant_op.constant(x_) - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, x)[0] - [fx_, gx_] = sess.run([fx, gx]) - - self.assertAllClose(f(x_), fx_) - self.assertAllClose(g(x_), gx_) - - def test_works_correctly_vector_of_vars(self): - with self.test_session() as sess: - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(2)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(3)) - sess.run([variables.global_variables_initializer()]) - - f = lambda z: z[0] * z[1] - g = lambda z: z[0]**2 * z[1]**2 / 2 - - z = array_ops.stack([x, y]) - fz = cg.custom_gradient(f(z), g(z), z, axis=0) - gz = gradients_impl.gradients(fz, variables.trainable_variables()) - [z_, fz_, gx_, gy_] = sess.run([z, fz, gz[0], gz[1]]) - - self.assertEqual(f(z_), fz_) - self.assertEqual(g(z_), gx_) - self.assertEqual(g(z_), gy_) - - def test_works_correctly_side_vars(self): - with self.test_session() as sess: - x_ = np.float32(2.1) # Adding extra tenth to force imprecision. - y_ = np.float32(3.1) - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(x_)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(y_)) - sess.run([variables.global_variables_initializer()]) - - f = lambda x: x * y - g = lambda z: math_ops.square(x) * y - - fx = cg.custom_gradient(f(x), g(x), x) - gx = gradients_impl.gradients(fx, variables.trainable_variables()) - [x_, fx_, gx_] = sess.run([x, fx, gx[0]]) - gy_ = gx[1] - - self.assertEqual(x_ * y_, fx_) - self.assertEqual(np.square(x_) * y_, gx_) - self.assertEqual(None, gy_) - - def test_works_correctly_fx_gx_manually_stopped(self): - with self.test_session() as sess: - x_ = np.float32(2.1) # Adding extra tenth to force imprecision. - y_ = np.float32(3.1) - x = variable_scope.get_variable( - name="x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(x_)) - y = variable_scope.get_variable( - name="y", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(y_)) - sess.run([variables.global_variables_initializer()]) - - stop = array_ops.stop_gradient # For readability. - - # Basically we need to stop the `x` portion of `f`. And when we supply the - # arg to `custom_gradient` we need to stop the complement, i.e., the `y` - # part. - f = lambda x: stop(x) * y - g = lambda x: stop(math_ops.square(x)) * y - fx = cg.custom_gradient(f(x), g(x), x + stop(y), - fx_gx_manually_stopped=True) - - gx = gradients_impl.gradients(fx, variables.trainable_variables()) - [x_, fx_, gx_, gy_] = sess.run([x, fx, gx[0], gx[1]]) - - self.assertEqual(x_ * y_, fx_) - self.assertEqual(np.square(x_) * y_, gx_) - self.assertEqual(x_, gy_) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py deleted file mode 100644 index 0a85862abfd744a86b9a38e10dbb5b985d0a0e94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for halton_sequence.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test - - -mc = monte_carlo_lib - - -class HaltonSequenceTest(test.TestCase): - - def test_known_values_small_bases(self): - with self.test_session(): - # The first five elements of the Halton sequence with base 2 and 3 - expected = np.array(((1. / 2, 1. / 3), - (1. / 4, 2. / 3), - (3. / 4, 1. / 9), - (1. / 8, 4. / 9), - (5. / 8, 7. / 9)), dtype=np.float32) - sample = halton.sample(2, num_samples=5) - self.assertAllClose(expected, sample.eval(), rtol=1e-6) - - def test_sample_indices(self): - with self.test_session(): - dim = 5 - indices = math_ops.range(10, dtype=dtypes.int32) - sample_direct = halton.sample(dim, num_samples=10) - sample_from_indices = halton.sample(dim, sample_indices=indices) - self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), - rtol=1e-6) - - def test_dtypes_works_correctly(self): - with self.test_session(): - dim = 3 - sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) - sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) - self.assertEqual(sample_float32.eval().dtype, np.float32) - self.assertEqual(sample_float64.eval().dtype, np.float64) - - def test_normal_integral_mean_and_var_correctly_estimated(self): - n = int(1000) - # This test is almost identical to the similarly named test in - # monte_carlo_test.py. The only difference is that we use the Halton - # samples instead of the random samples to evaluate the expectations. - # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N) - # (N=number of samples). For QMC in low dimensions, the expected convergence - # rate is ~ 1/N. Hence we should only need 1e3 samples as compared to the - # 1e6 samples used in the pseudo-random monte carlo. - with self.test_session(): - mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64) - mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64) - sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64) - sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64) - p = normal_lib.Normal(loc=mu_p, scale=sigma_p) - q = normal_lib.Normal(loc=mu_q, scale=sigma_q) - - cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) - q_sample = q.quantile(cdf_sample) - - # Compute E_p[X]. - e_x = mc.expectation_importance_sampler( - f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) - - # Compute E_p[X^2]. - e_x2 = mc.expectation_importance_sampler( - f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, - seed=42) - - stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) - # Keep the tolerance levels the same as in monte_carlo_test.py. - self.assertEqual(p.batch_shape, e_x.get_shape()) - self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01) - self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02) - - def test_docstring_example(self): - # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 - dim = 3 - with self.test_session(): - sample = halton.sample(dim, num_samples=num_samples) - - # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional - # hypercube. - powers = math_ops.range(1.0, limit=dim + 1) - integral = math_ops.reduce_mean( - math_ops.reduce_prod(sample ** powers, axis=-1)) - true_value = 1.0 / math_ops.reduce_prod(powers + 1.0) - - # Produces a relative absolute error of 1.7%. - self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) - - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. - - sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, - dtype=dtypes.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) - - integral_leaped = math_ops.reduce_mean( - math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) - self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py deleted file mode 100644 index 5bd834e56245ab4d874544cfd014fe59ae521ea8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ /dev/null @@ -1,863 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Hamiltonian Monte Carlo.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -import numpy as np -from scipy import stats - -from tensorflow.contrib.bayesflow.python.ops import hmc -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator - -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_linalg_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging_ops - - -def _reduce_variance(x, axis=None, keepdims=False): - sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) - return math_ops.reduce_mean( - math_ops.squared_difference(x, sample_mean), axis, keepdims) - - -class HMCTest(test.TestCase): - - def setUp(self): - self._shape_param = 5. - self._rate_param = 10. - - random_seed.set_random_seed(10003) - np.random.seed(10003) - - def assertAllFinite(self, x): - self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) - - def _log_gamma_log_prob(self, x, event_dims=()): - """Computes log-pdf of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - """ - return math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims) - - def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, - feed_dict=None): - step_size = array_ops.placeholder(np.float32, [], name="step_size") - hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict[hmc_lf_steps] = 1000 - - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - - m = random_ops.random_normal(array_ops.shape(x)) - log_prob_0 = self._log_gamma_log_prob(x, event_dims) - grad_0 = gradients_ops.gradients(log_prob_0, x) - old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) - - new_m, _, log_prob_1, _ = _leapfrog_integrator( - current_momentums=[m], - target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), - current_state_parts=[x], - step_sizes=[step_size], - num_leapfrog_steps=hmc_lf_steps, - current_target_log_prob=log_prob_0, - current_grads_target_log_prob=grad_0) - new_m = new_m[0] - - new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, - event_dims) - - x_shape = sess.run(x, feed_dict).shape - event_size = np.prod(x_shape[independent_chain_ndims:]) - feed_dict[step_size] = 0.1 / event_size - old_energy_, new_energy_ = sess.run([old_energy, new_energy], - feed_dict) - logging_ops.vlog(1, "average energy relative change: {}".format( - (1. - new_energy_ / old_energy_).mean())) - self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) - - def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): - """Tests the long-term energy conservation of the leapfrog integrator. - - The leapfrog integrator is symplectic, so for sufficiently small step - sizes it should be possible to run it more or less indefinitely without - the energy of the system blowing up or collapsing. - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._integrator_conserves_energy(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper(0) - - def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper(1) - - def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper(2) - - def testIntegratorEnergyConservation3(self): - self._integrator_conserves_energy_wrapper(3) - - def testSampleChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - num_results = 10 - independent_chain_ndims = 1 - - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - kwargs = dict( - target_log_prob_fn=log_gamma_log_prob, - current_state=np.random.rand(4, 3, 2), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=150, - seed=52, - ) - - samples0, kernel_results0 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=2 * num_results, - num_steps_between_results=0).items()))) - - samples1, kernel_results1 = hmc.sample_chain( - **dict(list(kwargs.items()) + list(dict( - num_results=num_results, - num_steps_between_results=1).items()))) - - [ - samples0_, - samples1_, - target_log_prob0_, - target_log_prob1_, - ] = sess.run([ - samples0, - samples1, - kernel_results0.current_target_log_prob, - kernel_results1.current_target_log_prob, - ]) - self.assertAllClose(samples0_[::2], samples1_, - atol=1e-5, rtol=1e-5) - self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, - atol=1e-5, rtol=1e-5) - - def _chain_gets_correct_expectations(self, x, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - def log_gamma_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, - array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - num_results = array_ops.placeholder( - np.int32, [], name="num_results") - step_size = array_ops.placeholder( - np.float32, [], name="step_size") - num_leapfrog_steps = array_ops.placeholder( - np.int32, [], name="num_leapfrog_steps") - - if feed_dict is None: - feed_dict = {} - feed_dict.update({num_results: 150, - step_size: 0.05, - num_leapfrog_steps: 2}) - - samples, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_gamma_log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - num_burnin_steps=150, - seed=42) - - self.assertAllEqual(dict(target_calls=2), counter) - - expected_x = (math_ops.digamma(self._shape_param) - - np.log(self._rate_param)) - - expected_exp_x = self._shape_param / self._rate_param - - acceptance_probs_, samples_, expected_x_ = sess.run( - [kernel_results.acceptance_probs, samples, expected_x], - feed_dict) - - actual_x = samples_.mean() - actual_exp_x = np.exp(samples_).mean() - - logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( - expected_x_, expected_exp_x)) - logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( - actual_x, actual_exp_x)) - self.assertNear(actual_x, expected_x_, 2e-2) - self.assertNear(actual_exp_x, expected_exp_x, 2e-2) - self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), - acceptance_probs_ > 0.5) - self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), - acceptance_probs_ <= 1.) - - def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - feed_dict = {x_ph: np.random.rand(50, 10, 2)} - self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper(0) - - def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper(1) - - def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper(2) - - def testKernelResultsUsingTruncatedDistribution(self): - def log_prob(x): - return array_ops.where( - x >= 0., - -x - x**2, # Non-constant gradient. - array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) - # This log_prob has the property that it is likely to attract - # the HMC flow toward, and below, zero...but for x <=0, - # log_prob(x) = -inf, which should result in rejection, as well - # as a non-finite log_prob. Thus, this distribution gives us an opportunity - # to test out the kernel results ability to correctly capture rejections due - # to finite AND non-finite reasons. - # Why use a non-constant gradient? This ensures the leapfrog integrator - # will not be exact. - - num_results = 1000 - # Large step size, will give rejections due to integration error in addition - # to rejection due to going into a region of log_prob = -inf. - step_size = 0.1 - num_leapfrog_steps = 5 - num_chains = 2 - - with self.test_session(graph=ops.Graph()) as sess: - - # Start multiple independent chains. - initial_state = ops.convert_to_tensor([0.1] * num_chains) - - states, kernel_results = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=log_prob, - current_state=initial_state, - step_size=step_size, - num_leapfrog_steps=num_leapfrog_steps, - seed=42) - - states_, kernel_results_ = sess.run([states, kernel_results]) - pstates_ = kernel_results_.proposed_state - - neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) - - # First: Test that the mathematical properties of the above log prob - # function in conjunction with HMC show up as expected in kernel_results_. - - # We better have log_prob = -inf some of the time. - self.assertLess(0, neg_inf_mask.sum()) - # We better have some rejections due to something other than -inf. - self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) - # We better have been accepted a decent amount, even near the end of the - # chain, or else this HMC run just got stuck at some point. - self.assertLess( - 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) - # We better not have any NaNs in proposed state or log_prob. - # We may have some NaN in grads, which involve multiplication/addition due - # to gradient rules. This is the known "NaN grad issue with tf.where." - self.assertAllEqual(np.zeros_like(states_), - np.isnan(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual(np.zeros_like(states_), - np.isnan(states_)) - # We better not have any +inf in states, grads, or log_prob. - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_target_log_prob)) - self.assertAllEqual( - np.zeros_like(states_), - np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) - self.assertAllEqual(np.zeros_like(states_), - np.isposinf(states_)) - - # Second: Test that kernel_results is congruent with itself and - # acceptance/rejection of states. - - # Proposed state is negative iff proposed target log prob is -inf. - np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) - np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) - - # Acceptance probs are zero whenever proposed state is negative. - self.assertAllEqual( - np.zeros_like(pstates_[neg_inf_mask]), - kernel_results_.acceptance_probs[neg_inf_mask]) - - # The move is accepted ==> state = proposed state. - self.assertAllEqual( - states_[kernel_results_.is_accepted], - pstates_[kernel_results_.is_accepted], - ) - # The move was rejected <==> state[t] == state[t - 1]. - for t in range(1, num_results): - for i in range(num_chains): - if kernel_results_.is_accepted[t, i]: - self.assertNotEqual(states_[t, i], states_[t - 1, i]) - else: - self.assertEqual(states_[t, i], states_[t - 1, i]) - - def _kernel_leaves_target_invariant(self, initial_draws, - independent_chain_ndims, - sess, feed_dict=None): - def log_gamma_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - def fake_log_prob(x): - """Cooled version of the target distribution.""" - return 1.1 * log_gamma_log_prob(x) - - step_size = array_ops.placeholder(np.float32, [], name="step_size") - - if feed_dict is None: - feed_dict = {} - - feed_dict[step_size] = 0.4 - - sample, kernel_results = hmc.kernel( - target_log_prob_fn=log_gamma_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=43) - - bad_sample, bad_kernel_results = hmc.kernel( - target_log_prob_fn=fake_log_prob, - current_state=initial_draws, - step_size=step_size, - num_leapfrog_steps=5, - seed=44) - - [ - acceptance_probs_, - bad_acceptance_probs_, - initial_draws_, - updated_draws_, - fake_draws_, - ] = sess.run([ - kernel_results.acceptance_probs, - bad_kernel_results.acceptance_probs, - initial_draws, - sample, - bad_sample, - ], feed_dict) - - # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_.mean(), 0.5) - - # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs_.mean(), 0.99) - self.assertLess(bad_acceptance_probs_.mean(), 0.99) - - _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), - updated_draws_.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), - fake_draws_.flatten()) - - logging_ops.vlog(1, "acceptance rate for true target: {}".format( - acceptance_probs_.mean())) - logging_ops.vlog(1, "acceptance rate for fake target: {}".format( - bad_acceptance_probs_.mean())) - logging_ops.vlog(1, "K-S p-value for true target: {}".format( - ks_p_value_true)) - logging_ops.vlog(1, "K-S p-value for fake target: {}".format( - ks_p_value_fake)) - # Make sure that the MCMC update hasn't changed the empirical CDF much. - self.assertGreater(ks_p_value_true, 1e-3) - # Confirm that targeting the wrong distribution does - # significantly change the empirical CDF. - self.assertLess(ks_p_value_fake, 1e-6) - - def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): - """Tests that the kernel leaves the target distribution invariant. - - Draws some independent samples from the target distribution, - applies an iteration of the MCMC kernel, then runs a - Kolmogorov-Smirnov test to determine if the distribution of the - MCMC-updated samples has changed. - - We also confirm that running the kernel with a different log-pdf - does change the target distribution. (And that we can detect that.) - - Args: - independent_chain_ndims: Python `int` scalar representing the number of - dims associated with independent chains. - """ - with self.test_session(graph=ops.Graph()) as sess: - initial_draws = np.log(np.random.gamma(self._shape_param, - size=[50000, 2, 2])) - initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name="x_ph") - - feed_dict = {x_ph: initial_draws} - - self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, - sess, feed_dict) - - def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper(1) - - def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper(2) - - def testKernelLeavesTargetInvariant3(self): - self._kernel_leaves_target_invariant_wrapper(3) - - def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - - def proposal_log_prob(x): - counter["proposal_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - if feed_dict is None: - feed_dict = {} - - num_steps = 200 - - _, ais_weights, _ = hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal_log_prob, - num_steps=num_steps, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=init, - num_leapfrog_steps=2, - seed=45) - - # We have three calls because the calculation of `ais_weights` entails - # another call to the `convex_combined_log_prob_fn`. We could refactor - # things to avoid this, if needed (eg, b/72994218). - self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) - - event_shape = array_ops.shape(init)[independent_chain_ndims:] - event_size = math_ops.reduce_prod(event_shape) - - log_true_normalizer = ( - -self._shape_param * math_ops.log(self._rate_param) - + math_ops.lgamma(self._shape_param)) - log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) - - log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) - - np.log(num_steps)) - - ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) - ais_weights_size = array_ops.size(ais_weights) - standard_error = math_ops.sqrt( - _reduce_variance(ratio_estimate_true) - / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) - - [ - ratio_estimate_true_, - log_true_normalizer_, - log_estimated_normalizer_, - standard_error_, - ais_weights_size_, - event_size_, - ] = sess.run([ - ratio_estimate_true, - log_true_normalizer, - log_estimated_normalizer, - standard_error, - ais_weights_size, - event_size, - ], feed_dict) - - logging_ops.vlog(1, " log_true_normalizer: {}\n" - " log_estimated_normalizer: {}\n" - " ais_weights_size: {}\n" - " event_size: {}\n".format( - log_true_normalizer_, - log_estimated_normalizer_, - ais_weights_size_, - event_size_)) - self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_) - - def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): - """Tests that AIS yields reasonable estimates of normalizers.""" - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - initial_draws = np.random.normal(size=[30, 2, 1]) - self._ais_gets_correct_log_normalizer( - x_ph, - independent_chain_ndims, - sess, - feed_dict={x_ph: initial_draws}) - - def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper(1) - - def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper(2) - - def testAIS3(self): - self._ais_gets_correct_log_normalizer_wrapper(3) - - def testSampleAIChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - independent_chain_ndims = 1 - x = np.random.rand(4, 3, 2) - - def proposal_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - ais_kwargs = dict( - proposal_log_prob_fn=proposal_log_prob, - num_steps=200, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=x, - num_leapfrog_steps=2, - seed=53) - - _, ais_weights0, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - _, ais_weights1, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - [ais_weights0_, ais_weights1_] = sess.run([ - ais_weights0, ais_weights1]) - - self.assertAllClose(ais_weights0_, ais_weights1_, - atol=1e-5, rtol=1e-5) - - def testNanRejection(self): - """Tests that an update that yields NaN potentials gets rejected. - - We run HMC with a target distribution that returns NaN - log-likelihoods if any element of x < 0, and unit-scale - exponential log-likelihoods otherwise. The exponential potential - pushes x towards 0, ensuring that any reasonably large update will - push us over the edge into NaN territory. - """ - def _unbounded_exponential_log_prob(x): - """An exponential distribution with log-likelihood NaN for x < 0.""" - per_element_potentials = array_ops.where( - x < 0., - array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), - -x) - return math_ops.reduce_sum(per_element_potentials) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_unbounded_exponential_log_prob, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=46) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) - - def testNanFromGradsDontPropagate(self): - """Test that update with NaN gradients does not cause NaN in results.""" - def _nan_log_prob_with_nan_gradient(x): - return np.nan * math_ops.reduce_sum(x) - - with self.test_session(graph=ops.Graph()) as sess: - initial_x = math_ops.linspace(0.01, 5, 10) - updated_x, kernel_results = hmc.kernel( - target_log_prob_fn=_nan_log_prob_with_nan_gradient, - current_state=initial_x, - step_size=2., - num_leapfrog_steps=5, - seed=47) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) - - logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) - logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) - - self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) - - self.assertAllFinite( - gradients_ops.gradients(updated_x, initial_x)[0].eval()) - self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, initial_x)]) - self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( - kernel_results.proposed_grads_target_log_prob, - kernel_results.proposed_state)]) - - # Gradients of the acceptance probs and new log prob are not finite. - # self.assertAllFinite( - # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) - # self.assertAllFinite( - # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) - - def _testChainWorksDtype(self, dtype): - with self.test_session(graph=ops.Graph()) as sess: - states, kernel_results = hmc.sample_chain( - num_results=10, - target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), - current_state=np.zeros(5).astype(dtype), - step_size=0.01, - num_leapfrog_steps=10, - seed=48) - states_, acceptance_probs_ = sess.run( - [states, kernel_results.acceptance_probs]) - self.assertEqual(dtype, states_.dtype) - self.assertEqual(dtype, acceptance_probs_.dtype) - - def testChainWorksIn64Bit(self): - self._testChainWorksDtype(np.float64) - - def testChainWorksIn16Bit(self): - self._testChainWorksDtype(np.float16) - - def testChainWorksCorrelatedMultivariate(self): - dtype = np.float32 - true_mean = dtype([0, 0]) - true_cov = dtype([[1, 0.5], - [0.5, 1]]) - num_results = 2000 - counter = collections.Counter() - with self.test_session(graph=ops.Graph()) as sess: - def target_log_prob(x, y): - counter["target_calls"] += 1 - # Corresponds to unnormalized MVN. - # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) - z = array_ops.stack([x, y], axis=-1) - true_mean - z = array_ops.squeeze( - gen_linalg_ops.matrix_triangular_solve( - np.linalg.cholesky(true_cov), - z[..., array_ops.newaxis]), - axis=-1) - return -0.5 * math_ops.reduce_sum(z**2., axis=-1) - states, _ = hmc.sample_chain( - num_results=num_results, - target_log_prob_fn=target_log_prob, - current_state=[dtype(-2), dtype(2)], - step_size=[0.5, 0.5], - num_leapfrog_steps=2, - num_burnin_steps=200, - num_steps_between_results=1, - seed=54) - self.assertAllEqual(dict(target_calls=2), counter) - states = array_ops.stack(states, axis=-1) - self.assertEqual(num_results, states.shape[0].value) - sample_mean = math_ops.reduce_mean(states, axis=0) - x = states - sample_mean - sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) - [sample_mean_, sample_cov_] = sess.run([ - sample_mean, sample_cov]) - self.assertAllClose(true_mean, sample_mean_, - atol=0.05, rtol=0.) - self.assertAllClose(true_cov, sample_cov_, - atol=0., rtol=0.1) - - -class _EnergyComputationTest(object): - - def testHandlesNanFromPotential(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - target_log_prob, proposed_target_log_prob = [ - self.dtype(x.flatten()) for x in np.meshgrid(x, x)] - num_chains = len(target_log_prob) - dummy_momentums = [-1, 1] - momentums = [self.dtype([dummy_momentums] * num_chains)] - proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - self.assertAllEqual(np.ones_like(grads_).astype(np.bool), - np.isfinite(grads_)) - - def testHandlesNanFromKinetic(self): - with self.test_session(graph=ops.Graph()) as sess: - x = [1, np.inf, -np.inf, np.nan] - momentums, proposed_momentums = [ - [np.reshape(self.dtype(x), [-1, 1])] - for x in np.meshgrid(x, x)] - num_chains = len(momentums[0]) - target_log_prob = np.ones(num_chains, self.dtype) - proposed_target_log_prob = np.ones(num_chains, self.dtype) - - target_log_prob = ops.convert_to_tensor(target_log_prob) - momentums = [ops.convert_to_tensor(momentums[0])] - proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) - proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] - - energy = _compute_energy_change( - target_log_prob, - momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims=1) - grads = gradients_ops.gradients(energy, momentums) - - [actual_energy, grads_] = sess.run([energy, grads]) - - # Ensure energy is `inf` (note: that's positive inf) in weird cases and - # finite otherwise. - expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) - self.assertAllEqual(expected_energy, actual_energy) - - # Ensure gradient is finite. - g = grads_[0].reshape([len(x), len(x)])[:, 0] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) - - # The remaining gradients are nan because the momentum was itself nan or - # inf. - g = grads_[0].reshape([len(x), len(x)])[:, 1:] - self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) - - -class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): - dtype = np.float16 - - -class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): - dtype = np.float32 - - -class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): - dtype = np.float64 - - -class _HMCHandlesLists(object): - - def testStateParts(self): - with self.test_session(graph=ops.Graph()) as sess: - dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) - dist_y = independent_lib.Independent( - gamma_lib.Gamma(concentration=self.dtype([1, 2]), - rate=self.dtype([0.5, 0.75])), - reinterpreted_batch_ndims=1) - def target_log_prob(x, y): - return dist_x.log_prob(x) + dist_y.log_prob(y) - x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] - samples, _ = hmc.sample_chain( - num_results=int(2e3), - target_log_prob_fn=target_log_prob, - current_state=x0, - step_size=0.85, - num_leapfrog_steps=3, - num_burnin_steps=int(250), - seed=49) - actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] - actual_vars = [_reduce_variance(s, axis=0) for s in samples] - expected_means = [dist_x.mean(), dist_y.mean()] - expected_vars = [dist_x.variance(), dist_y.variance()] - [ - actual_means_, - actual_vars_, - expected_means_, - expected_vars_, - ] = sess.run([ - actual_means, - actual_vars, - expected_means, - expected_vars, - ]) - self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) - self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) - - -class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): - dtype = np.float32 - - -class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py deleted file mode 100644 index 750afb6654311fea30a1dc6b31b20aa3b4160ae2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for convolutional Bayesian layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib -from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - - -class Counter(object): - """Helper class to manage incrementing a counting `int`.""" - - def __init__(self): - self._value = -1 - - @property - def value(self): - return self._value - - def __call__(self): - self._value += 1 - return self._value - - -class MockDistribution(independent_lib.Independent): - """Monitors layer calls to the underlying distribution.""" - - def __init__(self, result_sample, result_log_prob, loc=None, scale=None): - self.result_sample = result_sample - self.result_log_prob = result_log_prob - self.result_loc = loc - self.result_scale = scale - self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) - if loc is not None and scale is not None: - self.result_distribution = normal_lib.Normal(loc=self.result_loc, - scale=self.result_scale) - self.called_log_prob = Counter() - self.called_sample = Counter() - self.called_loc = Counter() - self.called_scale = Counter() - - def log_prob(self, *args, **kwargs): - self.called_log_prob() - return self.result_log_prob - - def sample(self, *args, **kwargs): - self.called_sample() - return self.result_sample - - @property - def distribution(self): # for dummy check on Independent(Normal) - return self.result_distribution - - @property - def loc(self): - self.called_loc() - return self.result_loc - - @property - def scale(self): - self.called_scale() - return self.result_scale - - -class MockKLDivergence(object): - """Monitors layer calls to the divergence implementation.""" - - def __init__(self, result): - self.result = result - self.args = [] - self.called = Counter() - - def __call__(self, *args, **kwargs): - self.called() - self.args.append(args) - return self.result - - -class ConvVariational(test.TestCase): - - def _testKLPenaltyKernel(self, layer_class): - with self.test_session(): - layer = layer_class(filters=2, kernel_size=3) - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform([2, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 1) - self.assertListEqual(layer.losses, losses) - - def _testKLPenaltyBoth(self, layer_class): - def _make_normal(dtype, *args): # pylint: disable=unused-argument - return normal_lib.Normal( - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) - with self.test_session(): - layer = layer_class( - filters=2, - kernel_size=3, - bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), - bias_prior_fn=_make_normal) - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform([2, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 2) - self.assertListEqual(layer.losses, losses) - - def _testConvSetUp(self, layer_class, batch_size, depth=None, - height=None, width=None, channels=None, filters=None, - **kwargs): - seed = Counter() - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform( - [batch_size, width, channels], seed=seed()) - kernel_size = (2,) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform( - [batch_size, height, width, channels], seed=seed()) - kernel_size = (2, 2) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform( - [batch_size, depth, height, width, channels], seed=seed()) - kernel_size = (2, 2, 2) - - kernel_shape = kernel_size + (channels, filters) - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_shape, seed=seed()), - scale=random_ops.random_uniform(kernel_shape, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_shape, seed=seed())) - - bias_size = (filters,) - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) - - layer = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence, - **kwargs) - - outputs = layer(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - return (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, - layer, inputs, outputs, kl_penalty, kernel_shape) - - def _testConvReparameterization(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty, kernel_shape) = self._testConvSetUp( - layer_class, batch_size, - depth=depth, height=height, width=width, channels=channels, - filters=filters) - - convolution_op = nn_ops.Convolution( - tensor_shape.TensorShape(inputs.shape), - filter_shape=tensor_shape.TensorShape(kernel_shape), - padding="SAME") - expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) - expected_outputs = nn.bias_add(expected_outputs, - bias_posterior.result_sample, - data_format="NHWC") - - [ - expected_outputs_, actual_outputs_, - expected_kernel_, actual_kernel_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_posterior.result_sample, layer.kernel_posterior_tensor, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_kernel_, actual_kernel_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - kernel_posterior.result_sample]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def _testConvFlipout(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty, kernel_shape) = self._testConvSetUp( - layer_class, batch_size, - depth=depth, height=height, width=width, channels=channels, - filters=filters, seed=44) - - convolution_op = nn_ops.Convolution( - tensor_shape.TensorShape(inputs.shape), - filter_shape=tensor_shape.TensorShape(kernel_shape), - padding="SAME") - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(kernel_posterior.result_loc), - scale=kernel_posterior.result_scale) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - - expected_outputs = convolution_op( - inputs, kernel_posterior.distribution.loc) - - input_shape = array_ops.shape(inputs) - output_shape = array_ops.shape(expected_outputs) - batch_shape = array_ops.expand_dims(input_shape[0], 0) - channels = input_shape[-1] - rank = len(inputs.get_shape()) - 2 - - sign_input = random_ops.random_uniform( - array_ops.concat([batch_shape, - array_ops.expand_dims(channels, 0)], 0), - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=layer.seed) - sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) - sign_output = random_ops.random_uniform( - array_ops.concat([batch_shape, - array_ops.expand_dims(filters, 0)], 0), - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=distribution_util.gen_new_seed( - layer.seed, salt="conv_flipout")) - sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) - for _ in range(rank): - sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) - sign_output = array_ops.expand_dims(sign_output, 1) - - sign_input = array_ops.tile( # tile for element-wise op broadcasting - sign_input, - [1] + [input_shape[i + 1] for i in range(rank)] + [1]) - sign_output = array_ops.tile( - sign_output, - [1] + [output_shape[i + 1] for i in range(rank)] + [1]) - - perturbed_inputs = convolution_op( - inputs * sign_input, expected_kernel_posterior_affine_tensor) - perturbed_inputs *= sign_output - - expected_outputs += perturbed_inputs - expected_outputs = nn.bias_add(expected_outputs, - bias_posterior.result_sample, - data_format="NHWC") - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, kernel_prior.distribution, None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def _testRandomConvFlipout(self, layer_class): - batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 - with self.test_session() as sess: - seed = Counter() - if layer_class in (prob_layers_lib.Conv1DReparameterization, - prob_layers_lib.Conv1DFlipout): - inputs = random_ops.random_uniform( - [batch_size, width, channels], seed=seed()) - kernel_size = (2,) - elif layer_class in (prob_layers_lib.Conv2DReparameterization, - prob_layers_lib.Conv2DFlipout): - inputs = random_ops.random_uniform( - [batch_size, height, width, channels], seed=seed()) - kernel_size = (2, 2) - elif layer_class in (prob_layers_lib.Conv3DReparameterization, - prob_layers_lib.Conv3DFlipout): - inputs = random_ops.random_uniform( - [batch_size, depth, height, width, channels], seed=seed()) - kernel_size = (2, 2, 2) - - kernel_shape = kernel_size + (channels, filters) - bias_size = (filters,) - - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform( - kernel_shape, seed=seed()), - scale=random_ops.random_uniform( - kernel_shape, seed=seed()), - result_log_prob=random_ops.random_uniform( - kernel_shape, seed=seed()), - result_sample=random_ops.random_uniform( - kernel_shape, seed=seed())) - bias_posterior = MockDistribution( - loc=random_ops.random_uniform( - bias_size, seed=seed()), - scale=random_ops.random_uniform( - bias_size, seed=seed()), - result_log_prob=random_ops.random_uniform( - bias_size, seed=seed()), - result_sample=random_ops.random_uniform( - bias_size, seed=seed())) - layer_one = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=44) - layer_two = layer_class( - filters=filters, - kernel_size=kernel_size, - padding="SAME", - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=45) - - outputs_one = layer_one(inputs) - outputs_two = layer_two(inputs) - - outputs_one_, outputs_two_ = sess.run([ - outputs_one, outputs_two]) - - self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), - np.prod(outputs_one_.shape)) - - def testKLPenaltyKernelConv1DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization) - - def testKLPenaltyKernelConv2DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization) - - def testKLPenaltyKernelConv3DReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization) - - def testKLPenaltyKernelConv1DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout) - - def testKLPenaltyKernelConv2DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout) - - def testKLPenaltyKernelConv3DFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout) - - def testKLPenaltyBothConv1DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization) - - def testKLPenaltyBothConv2DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization) - - def testKLPenaltyBothConv3DReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization) - - def testKLPenaltyBothConv1DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout) - - def testKLPenaltyBothConv2DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout) - - def testKLPenaltyBothConv3DFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout) - - def testConv1DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization) - - def testConv2DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization) - - def testConv3DReparameterization(self): - self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization) - - def testConv1DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv1DFlipout) - - def testConv2DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv2DFlipout) - - def testConv3DFlipout(self): - self._testConvFlipout(prob_layers_lib.Conv3DFlipout) - - def testRandomConv1DFlipout(self): - self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py deleted file mode 100644 index 342f38ccec7ec74db1b393d6cdc22300205cc547..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for dense Bayesian layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as prob_layers_lib -from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - - -class Counter(object): - """Helper class to manage incrementing a counting `int`.""" - - def __init__(self): - self._value = -1 - - @property - def value(self): - return self._value - - def __call__(self): - self._value += 1 - return self._value - - -class MockDistribution(independent_lib.Independent): - """Monitors layer calls to the underlying distribution.""" - - def __init__(self, result_sample, result_log_prob, loc=None, scale=None): - self.result_sample = result_sample - self.result_log_prob = result_log_prob - self.result_loc = loc - self.result_scale = scale - self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) - if loc is not None and scale is not None: - self.result_distribution = normal_lib.Normal(loc=self.result_loc, - scale=self.result_scale) - self.called_log_prob = Counter() - self.called_sample = Counter() - self.called_loc = Counter() - self.called_scale = Counter() - - def log_prob(self, *args, **kwargs): - self.called_log_prob() - return self.result_log_prob - - def sample(self, *args, **kwargs): - self.called_sample() - return self.result_sample - - @property - def distribution(self): # for dummy check on Independent(Normal) - return self.result_distribution - - @property - def loc(self): - self.called_loc() - return self.result_loc - - @property - def scale(self): - self.called_scale() - return self.result_scale - - -class MockKLDivergence(object): - """Monitors layer calls to the divergence implementation.""" - - def __init__(self, result): - self.result = result - self.args = [] - self.called = Counter() - - def __call__(self, *args, **kwargs): - self.called() - self.args.append(args) - return self.result - - -class DenseVariational(test.TestCase): - - def _testKLPenaltyKernel(self, layer_class): - with self.test_session(): - layer = layer_class(units=2) - inputs = random_ops.random_uniform([2, 3], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 1) - self.assertListEqual(layer.losses, losses) - - def _testKLPenaltyBoth(self, layer_class): - def _make_normal(dtype, *args): # pylint: disable=unused-argument - return normal_lib.Normal( - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) - with self.test_session(): - layer = layer_class( - units=2, - bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), - bias_prior_fn=_make_normal) - inputs = random_ops.random_uniform([2, 3], seed=1) - - # No keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 0) - self.assertListEqual(layer.losses, losses) - - _ = layer(inputs) - - # Yes keys. - losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(losses), 2) - self.assertListEqual(layer.losses, losses) - - def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size, - **kwargs): - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_size, seed=seed()), - scale=random_ops.random_uniform(kernel_size, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) - - layer = layer_class( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence, - **kwargs) - - outputs = layer(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - return (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, - layer, inputs, outputs, kl_penalty) - - def testKLPenaltyKernelReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseReparameterization) - - def testKLPenaltyKernelLocalReparameterization(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseLocalReparameterization) - - def testKLPenaltyKernelFlipout(self): - self._testKLPenaltyKernel(prob_layers_lib.DenseFlipout) - - def testKLPenaltyBothReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseReparameterization) - - def testKLPenaltyBothLocalReparameterization(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseLocalReparameterization) - - def testKLPenaltyBothFlipout(self): - self._testKLPenaltyBoth(prob_layers_lib.DenseFlipout) - - def testDenseReparameterization(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseReparameterization, - batch_size, in_size, out_size) - - expected_outputs = ( - math_ops.matmul(inputs, kernel_posterior.result_sample) + - bias_posterior.result_sample) - - [ - expected_outputs_, actual_outputs_, - expected_kernel_, actual_kernel_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_posterior.result_sample, layer.kernel_posterior_tensor, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_kernel_, actual_kernel_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - kernel_posterior.result_sample]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testDenseLocalReparameterization(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseLocalReparameterization, - batch_size, in_size, out_size) - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=math_ops.matmul(inputs, kernel_posterior.result_loc), - scale=math_ops.matmul( - inputs**2., kernel_posterior.result_scale**2)**0.5) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - expected_outputs = (expected_kernel_posterior_affine_tensor + - bias_posterior.result_sample) - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, - kernel_prior.distribution, - None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testDenseFlipout(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - (kernel_posterior, kernel_prior, kernel_divergence, - bias_posterior, bias_prior, bias_divergence, layer, inputs, - outputs, kl_penalty) = self._testDenseSetUp( - prob_layers_lib.DenseFlipout, - batch_size, in_size, out_size, seed=44) - - expected_kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(kernel_posterior.result_loc), - scale=kernel_posterior.result_scale) - expected_kernel_posterior_affine_tensor = ( - expected_kernel_posterior_affine.sample(seed=42)) - - sign_input = random_ops.random_uniform( - [batch_size, in_size], - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=layer.seed) - sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) - sign_output = random_ops.random_uniform( - [batch_size, out_size], - minval=0, - maxval=2, - dtype=dtypes.int32, - seed=distribution_util.gen_new_seed( - layer.seed, salt="dense_flipout")) - sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) - perturbed_inputs = math_ops.matmul( - inputs * sign_input, expected_kernel_posterior_affine_tensor) - perturbed_inputs *= sign_output - - expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc) - expected_outputs += perturbed_inputs - expected_outputs += bias_posterior.result_sample - - [ - expected_outputs_, actual_outputs_, - expected_kernel_divergence_, actual_kernel_divergence_, - expected_bias_, actual_bias_, - expected_bias_divergence_, actual_bias_divergence_, - ] = sess.run([ - expected_outputs, outputs, - kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, layer.bias_posterior_tensor, - bias_divergence.result, kl_penalty[1], - ]) - - self.assertAllClose( - expected_bias_, actual_bias_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_outputs_, actual_outputs_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_kernel_divergence_, actual_kernel_divergence_, - rtol=1e-6, atol=0.) - self.assertAllClose( - expected_bias_divergence_, actual_bias_divergence_, - rtol=1e-6, atol=0.) - - self.assertAllEqual( - [[kernel_posterior.distribution, kernel_prior.distribution, None]], - kernel_divergence.args) - - self.assertAllEqual( - [[bias_posterior.distribution, - bias_prior.distribution, - bias_posterior.result_sample]], - bias_divergence.args) - - def testRandomDenseFlipout(self): - batch_size, in_size, out_size = 2, 3, 4 - with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - scale=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - result_log_prob=random_ops.random_uniform( - [in_size, out_size], seed=seed()), - result_sample=random_ops.random_uniform( - [in_size, out_size], seed=seed())) - bias_posterior = MockDistribution( - loc=random_ops.random_uniform( - [out_size], seed=seed()), - scale=random_ops.random_uniform( - [out_size], seed=seed()), - result_log_prob=random_ops.random_uniform( - [out_size], seed=seed()), - result_sample=random_ops.random_uniform( - [out_size], seed=seed())) - layer_one = prob_layers_lib.DenseFlipout( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=44) - layer_two = prob_layers_lib.DenseFlipout( - units=out_size, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - seed=45) - - outputs_one = layer_one(inputs) - outputs_two = layer_two(inputs) - - outputs_one_, outputs_two_ = sess.run([ - outputs_one, outputs_two]) - - self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), out_size) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py deleted file mode 100644 index 52e36e135d95c1ec919c710f35d59073c2134d05..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for MCMC diagnostic utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import spectral_ops_test_util -from tensorflow.python.platform import test - -rng = np.random.RandomState(42) - - -class _EffectiveSampleSizeTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to implement `use_static_shape`.") - - def _check_versus_expected_effective_sample_size(self, - x_, - expected_ess, - sess, - atol=1e-2, - rtol=1e-2, - filter_threshold=None, - filter_beyond_lag=None): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - ess = mcmc_diagnostics.effective_sample_size( - x, - filter_threshold=filter_threshold, - filter_beyond_lag=filter_beyond_lag) - if self.use_static_shape: - self.assertAllEqual(x.shape[1:], ess.shape) - - ess_ = sess.run(ess) - - self.assertAllClose( - np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol) - - def testIidRank1NormalHasFullEssMaxLags10(self): - # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we - # should have a good estimate of ESS, and it should be close to the full - # sequence length of 5000. - # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only - # since we know the correlation length should be zero right away. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=10, - filter_threshold=None, - rtol=0.3) - - def testIidRank2NormalHasFullEssMaxLags10(self): - # See similar test for Rank1Normal for reasoning. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000, 2).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=10, - filter_threshold=None, - rtol=0.3) - - def testIidRank1NormalHasFullEssMaxLagThresholdZero(self): - # With a length 5000 iid normal sequence, and filter_threshold = 0, - # we should have a super-duper estimate of ESS, and it should be very close - # to the full sequence length of 5000. - # The choice of filter_beyond_lag = 0 means we cutoff as soon as the - # auto-corris below zero. This should happen very quickly, due to the fact - # that the theoretical auto-corr is [1, 0, 0,...] - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): - # See similar test for Rank1Normal for reasoning. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=rng.randn(5000, 2).astype(np.float32), - expected_ess=5000, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): - # Create x_, such that - # x_[i] = iid_x_[0], i = 0,...,9 - # x_[i] = iid_x_[1], i = 10,..., 19, - # and so on. - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=x_, - expected_ess=50000 // 10, - sess=sess, - filter_beyond_lag=50, - filter_threshold=None, - rtol=0.2) - - def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( - self): - # Create x_, such that - # x_[i] = iid_x_[0], i = 0,...,9 - # x_[i] = iid_x_[1], i = 10,..., 19, - # and so on. - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - self._check_versus_expected_effective_sample_size( - x_=x_, - expected_ess=50000 // 10, - sess=sess, - filter_beyond_lag=None, - filter_threshold=0., - rtol=0.1) - - def testListArgs(self): - # x_ has correlation length 10 ==> ESS = N / 10 - # y_ has correlation length 1 ==> ESS = N - iid_x_ = rng.randn(5000, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) - y_ = rng.randn(50000).astype(np.float32) - states = [x_, x_, y_, y_] - filter_threshold = [0., None, 0., None] - filter_beyond_lag = [None, 5, None, 5] - - # See other tests for reasoning on tolerance. - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - ess = mcmc_diagnostics.effective_sample_size( - states, - filter_threshold=filter_threshold, - filter_beyond_lag=filter_beyond_lag) - ess_ = sess.run(ess) - self.assertAllEqual(4, len(ess_)) - - self.assertAllClose(50000 // 10, ess_[0], rtol=0.3) - self.assertAllClose(50000 // 10, ess_[1], rtol=0.3) - self.assertAllClose(50000, ess_[2], rtol=0.1) - self.assertAllClose(50000, ess_[3], rtol=0.1) - - def testMaxLagsThresholdLessThanNeg1SameAsNone(self): - # Setting both means we filter out items R_k from the auto-correlation - # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. - - # x_ has correlation length 10. - iid_x_ = rng.randn(500, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - - ess_none_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=None, filter_beyond_lag=None) - ess_none_200 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=None, filter_beyond_lag=200) - ess_neg2_200 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=-2., filter_beyond_lag=200) - ess_neg2_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=-2., filter_beyond_lag=None) - ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run( - [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none]) - - # filter_threshold=-2 <==> filter_threshold=None. - self.assertAllClose(ess_none_none_, ess_neg2_none_) - self.assertAllClose(ess_none_200_, ess_neg2_200_) - - def testMaxLagsArgsAddInAnOrManner(self): - # Setting both means we filter out items R_k from the auto-correlation - # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. - - # x_ has correlation length 10. - iid_x_ = rng.randn(500, 1).astype(np.float32) - x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) - with self.test_session() as sess: - with spectral_ops_test_util.fft_kernel_label_map(): - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - - ess_1_9 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=9) - ess_1_none = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=None) - ess_none_9 = mcmc_diagnostics.effective_sample_size( - x, filter_threshold=1., filter_beyond_lag=9) - ess_1_9_, ess_1_none_, ess_none_9_ = sess.run( - [ess_1_9, ess_1_none, ess_none_9]) - - # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10, - # filter_threshold = 1 <==> filter_beyond_lag = 9. - self.assertAllClose(ess_1_9_, ess_1_none_) - self.assertAllClose(ess_1_9_, ess_none_9_) - - -class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest): - - @property - def use_static_shape(self): - return True - - -class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest): - - @property - def use_static_shape(self): - return False - - -class _PotentialScaleReductionTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to impliment `use_static_shape`.") - - def testListOfStatesWhereFirstPassesSecondFails(self): - """Simple test showing API with two states. Read first!.""" - n_samples = 1000 - - # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass. - state_0 = rng.randn(n_samples, 2) - - # state_1 is three 4-variate chains taken from Normal(0, 1) that have been - # shifted. Since every chain is shifted, they are not the same, and the - # test should fail. - offset = np.array([1., -1., 2.]).reshape(3, 1) - state_1 = rng.randn(n_samples, 3, 4) + offset - - rhat = mcmc_diagnostics.potential_scale_reduction( - chains_states=[state_0, state_1], independent_chain_ndims=1) - - self.assertIsInstance(rhat, list) - with self.test_session() as sess: - rhat_0_, rhat_1_ = sess.run(rhat) - - # r_hat_0 should be close to 1, meaning test is passed. - self.assertAllEqual((), rhat_0_.shape) - self.assertAllClose(1., rhat_0_, rtol=0.02) - - # r_hat_1 should be greater than 1.2, meaning test has failed. - self.assertAllEqual((4,), rhat_1_.shape) - self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2) - - def check_results(self, state_, independent_chain_shape, should_pass): - sample_ndims = 1 - independent_chain_ndims = len(independent_chain_shape) - with self.test_session(): - state = array_ops.placeholder_with_default( - input=state_, shape=state_.shape if self.use_static_shape else None) - - rhat = mcmc_diagnostics.potential_scale_reduction( - state, independent_chain_ndims=independent_chain_ndims) - - if self.use_static_shape: - self.assertAllEqual( - state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape) - - rhat_ = rhat.eval() - if should_pass: - self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02) - else: - self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2) - - def iid_normal_chains_should_pass_wrapper(self, - sample_shape, - independent_chain_shape, - other_shape, - dtype=np.float32): - """Check results with iid normal chains.""" - - state_shape = sample_shape + independent_chain_shape + other_shape - state_ = rng.randn(*state_shape).astype(dtype) - - # The "other" dimensions do not have to be identical, just independent, so - # force them to not be identical. - if other_shape: - state_ *= rng.rand(*other_shape).astype(dtype) - - self.check_results(state_, independent_chain_shape, should_pass=True) - - def testPassingIIDNdimsAreIndependentOneOtherZero(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[4], other_shape=[]) - - def testPassingIIDNdimsAreIndependentOneOtherOne(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[3], other_shape=[7]) - - def testPassingIIDNdimsAreIndependentOneOtherTwo(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7]) - - def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self): - self.iid_normal_chains_should_pass_wrapper( - sample_shape=[10000], - independent_chain_shape=[2, 3], - other_shape=[5, 7], - dtype=np.float64) - - def offset_normal_chains_should_fail_wrapper( - self, sample_shape, independent_chain_shape, other_shape): - """Check results with normal chains that are offset from each other.""" - - state_shape = sample_shape + independent_chain_shape + other_shape - state_ = rng.randn(*state_shape) - - # Add a significant offset to the different (formerly iid) chains. - offset = np.linspace( - 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len( - sample_shape) + independent_chain_shape + [1] * len(other_shape)) - state_ += offset - - self.check_results(state_, independent_chain_shape, should_pass=False) - - def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self): - self.offset_normal_chains_should_fail_wrapper( - sample_shape=[10000], independent_chain_shape=[2], other_shape=[5]) - - -class PotentialScaleReductionStaticTest(test.TestCase, - _PotentialScaleReductionTest): - - @property - def use_static_shape(self): - return True - - def testIndependentNdimsLessThanOneRaises(self): - with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"): - mcmc_diagnostics.potential_scale_reduction( - rng.rand(2, 3, 4), independent_chain_ndims=0) - - -class PotentialScaleReductionDynamicTest(test.TestCase, - _PotentialScaleReductionTest): - - @property - def use_static_shape(self): - return False - - -class _ReduceVarianceTest(object): - - @property - def use_static_shape(self): - raise NotImplementedError( - "Subclass failed to impliment `use_static_shape`.") - - def check_versus_numpy(self, x_, axis, biased, keepdims): - with self.test_session(): - x_ = np.asarray(x_) - x = array_ops.placeholder_with_default( - input=x_, shape=x_.shape if self.use_static_shape else None) - var = mcmc_diagnostics._reduce_variance( - x, axis=axis, biased=biased, keepdims=keepdims) - np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims) - - if self.use_static_shape: - self.assertAllEqual(np_var.shape, var.shape) - - var_ = var.eval() - # We will mask below, which changes shape, so check shape explicitly here. - self.assertAllEqual(np_var.shape, var_.shape) - - # We get NaN when we divide by zero due to the size being the same as ddof - nan_mask = np.isnan(np_var) - if nan_mask.any(): - self.assertTrue(np.isnan(var_[nan_mask]).all()) - self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02) - - def testScalarBiasedTrue(self): - self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False) - - def testScalarBiasedFalse(self): - # This should result in NaN. - self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False) - - def testShape2x3x4AxisNoneBiasedFalseKeepdimsFalse(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False) - - def testShape2x3x4Axis1BiasedFalseKeepdimsTrue(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True) - - def testShape2x3x4x5Axis13BiasedFalseKeepdimsTrue(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True) - - def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self): - self.check_versus_numpy( - x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False) - - -class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest): - - @property - def use_static_shape(self): - return True - - -class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest): - - @property - def use_static_shape(self): - return False - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py deleted file mode 100644 index 63d93fad64d077aa385b72428665e841b6784b90..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for metropolis_hastings.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class McmcStepTest(test.TestCase): - - def test_density_increasing_step_accepted(self): - """Tests that if a transition increases density, it is always accepted.""" - target_log_density = lambda x: - x * x - state = variable_scope.get_variable('state', initializer=10.) - state_log_density = variable_scope.get_variable( - 'state_log_density', - initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - 'log_accept_ratio', initializer=0.) - - get_next_proposal = lambda x: (x - 1., None) - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, get_next_proposal, seed=1234) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - for j in range(9): - sess.run(step) - sample = sess.run(state) - sample_log_density = sess.run(state_log_density) - self.assertAlmostEqual(sample, 9 - j) - self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j)) - - def test_sample_properties(self): - """Tests that the samples converge to the target distribution.""" - - def target_log_density(x): - """Log-density corresponding to a normal distribution with mean = 4.""" - return - (x - 2.0) * (x - 2.0) * 0.5 - - # Use the uniform random walker to generate proposals. - proposal_fn = mh.uniform_random_proposal( - step_size=1.0, seed=1234) - - state = variable_scope.get_variable('state', initializer=0.0) - state_log_density = variable_scope.get_variable( - 'state_log_density', - initializer=target_log_density(state.initialized_value())) - - log_accept_ratio = variable_scope.get_variable( - 'log_accept_ratio', initializer=0.) - # Random walk MCMC converges slowly so need to put in enough iterations. - num_iterations = 5000 - step = mh.evolve(state, state_log_density, log_accept_ratio, - target_log_density, proposal_fn, seed=4321) - - init = variables.global_variables_initializer() - - sample_sum, sample_sq_sum = 0.0, 0.0 - with self.test_session() as sess: - sess.run(init) - for _ in np.arange(num_iterations): - # Allow for the mixing of the chain and discard these samples. - sess.run(step) - for _ in np.arange(num_iterations): - sess.run(step) - sample = sess.run(state) - sample_sum += sample - sample_sq_sum += sample * sample - - sample_mean = sample_sum / num_iterations - sample_variance = sample_sq_sum / num_iterations - sample_mean * sample_mean - # The samples have large autocorrelation which reduces the effective sample - # size. - self.assertAlmostEqual(sample_mean, 2.0, delta=0.1) - self.assertAlmostEqual(sample_variance, 1.0, delta=0.1) - - def test_normal_proposals(self): - """Tests that the normal proposals are correctly distributed.""" - - initial_points = array_ops.ones([10000], dtype=dtypes.float32) - proposal_fn = mh.normal_random_proposal( - scale=2.0, seed=1234) - proposal_points, _ = proposal_fn(initial_points) - - with self.test_session() as sess: - sample = sess.run(proposal_points) - - # It is expected that the elements in proposal_points have the same mean as - # initial_points and have the standard deviation that was supplied to the - # proposal scheme. - self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1) - self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1) - - def test_docstring_example(self): - """Tests the simplified docstring example with multiple chains.""" - - n = 2 # dimension of the problem - - # Generate 300 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = variable_scope.get_variable( - 'state', initializer=random_ops.random_normal( - [300, n], mean=3.0, dtype=dtypes.float32, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - math_ops.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = variable_scope.get_variable( - 'state_log_density', - initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = variable_scope.get_variable( - 'log_acceptance_ratio', - initializer=array_ops.zeros([300], dtype=dtypes.float32)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + random_ops.random_uniform( - array_ops.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = variables.initialize_all_variables() - with self.test_session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps. - for _ in range(10): - sess.run(stepper) - samples = sess.run(state) - covariance = np.eye(n) - # Verify that the estimated mean and covariance are close to the true - # values. - self.assertAlmostEqual( - np.max(np.abs(np.mean(samples, 0) - - np.zeros(n))), 0, - delta=0.1) - self.assertAlmostEqual( - np.max(np.abs(np.reshape(np.cov(samples, rowvar=False), [n**2]) - - np.reshape(covariance, [n**2]))), 0, - delta=0.2) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py deleted file mode 100644 index 756c25683bd4b0c8c77e9e28485ca2a85582999c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional test for GradientDescent.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import math -from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class SGLDOptimizerTest(test.TestCase): - - def testBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.53 - sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - - def testBasicMultiInstance(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - vara = variables.Variable([1.1, 2.1], dtype=dtype) - varb = variables.Variable([3.0, 4.0], dtype=dtype) - gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) - gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.5 - sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - sgd_optimizer2 = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate) - sgd_op2 = sgd_optimizer2.apply_gradients( - zip([gradsa, gradsb], [vara, varb])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) - - # Run 1 step of sgd - sgd_op.run() - sgd_op2.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval()) - - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval()) - self.assertNotEqual(sgd_optimizer.variable_scope, - sgd_optimizer2.variable_scope) - self.assertNotEqual(sgd_optimizer.variable_scope.name, - sgd_optimizer2.variable_scope.name) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer2._counter.eval()) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = constant_op.constant(3.0) - decay_rate = 0.5 - sgd_op = SGLDOptimizer( - lrate, preconditioner_decay_rate=constant_op.constant( - decay_rate)).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - - def testGradWrtRef(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - opt = SGLDOptimizer(3.0) - values = [1.0, 3.0] - vars_ = [variables.Variable([v], dtype=dtype) for v in values] - grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) - variables.global_variables_initializer().run() - for grad, _ in grads_and_vars: - self.assertAllCloseAccordingToType([1.0], grad.eval()) - - def testWithGlobalStep(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - global_step = variables.Variable(0, trainable=False) - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.1 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - - # Validate updated params and global_step - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) - - def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) - var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) - grads0 = ops.IndexedSlices( - constant_op.constant([0.1], shape=[1, 1], dtype=dtype), - constant_op.constant([0]), constant_op.constant([2, 1])) - grads1 = ops.IndexedSlices( - constant_op.constant([0.01], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), constant_op.constant([2, 1])) - decay_rate = 0.9 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) - self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + - (1 - decay_rate) * 0.1**2 + 1e-8)) - self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]], - var0.eval()) - grads_scaled = (0.5 * 0.01 / math.sqrt( - decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) - self.assertAllCloseAccordingToType( - [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py deleted file mode 100644 index f978cf86417dc5ff5412a3eee584330a266e0964..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for utility functions related to managing `tf.Variable`s.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import variable_utils - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.ops import variables as variables_ops -from tensorflow.python.platform import test - - -def test_fn(x): - x = ops.convert_to_tensor(x, name="x") - dtype = x.dtype.as_numpy_dtype - s = x.shape.as_list() - z = varscope_ops.get_variable( - name="z", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) - y = varscope_ops.get_variable( - name="y", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2) - return x + y + z - - -class _WrapCallableTest(object): - - def testDefaultArgsWorkCorrectly(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( - test_fn, [x]) - - varscope_ops.get_variable_scope().reuse_variables() - - result = wrapped_fn(self.dtype(2), [3, 4, 5], 0.5) - - y_actual = varscope_ops.get_variable("y", dtype=self.dtype) - z_actual = varscope_ops.get_variable("z", dtype=self.dtype) - - variables_ops.global_variables_initializer().run() - result_ = result.eval() - - self.assertEqual(self.dtype, result_.dtype) - self.assertAllEqual([5.5, 6.5, 7.5], result_) - self.assertAllEqual([y_actual, z_actual], vars_args) - - def testNonDefaultArgsWorkCorrectly(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - - _ = test_fn(self.dtype([0., 0.])) # Needed to create vars. - varscope_ops.get_variable_scope().reuse_variables() - - y_actual = varscope_ops.get_variable("y", dtype=self.dtype) - - wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( - test_fn, [x], possible_ancestor_vars=[y_actual]) - - result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y - - variables_ops.global_variables_initializer().run() - result_ = result.eval() - - self.assertEqual(self.dtype, result_.dtype) - self.assertAllEqual([2.5, 4.5], result_) - self.assertAllEqual([y_actual], vars_args) - - def testWarnings(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, _ = variable_utils.externalize_variables_as_args( - test_fn, [x], possible_ancestor_vars=[]) - varscope_ops.get_variable_scope().reuse_variables() - with warnings.catch_warnings(record=True) as w: - wrapped_fn(self.dtype(2)) - w = sorted(w, key=lambda w: str(w.message)) - self.assertEqual(2, len(w)) - self.assertRegexpMatches( - str(w[0].message), - r"Variable .* 'y:0' .* not found in bypass dict.") - self.assertRegexpMatches( - str(w[1].message), - r"Variable .* 'z:0' .* not found in bypass dict.") - - def testExceptions(self): - with self.test_session(): - x = constant_op.constant(self.dtype([0.1, 0.2])) - wrapped_fn, _ = variable_utils.externalize_variables_as_args( - test_fn, - [x], - possible_ancestor_vars=[], - assert_variable_override=True) - varscope_ops.get_variable_scope().reuse_variables() - with self.assertRaisesRegexp(ValueError, r"not found"): - wrapped_fn(self.dtype(2)) - - -class WrapCallableTest16(test.TestCase, _WrapCallableTest): - dtype = np.float16 - - -class WrapCallableTest32(test.TestCase, _WrapCallableTest): - dtype = np.float32 - - -class WrapCallableTest64(test.TestCase, _WrapCallableTest): - dtype = np.float64 - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py deleted file mode 100644 index 83c64dbe0fd586edcb784a5c09a4c133aaa99cff..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functional test for GradientDescent.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.contrib.bayesflow.python.ops.optimizers import VariationalSGDOptimizer -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class VariationalSGDOptimizerTest(test.TestCase): - - def testBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.53 - sgd_op = VariationalSGDOptimizer( - 1, - 1, - preconditioner_decay_rate=decay_rate, - max_learning_rate=3.0, - burnin_max_learning_rate=3.0, - use_single_learning_rate=True).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - - def testBasicMultiInstance(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - vara = variables.Variable([1.1, 2.1], dtype=dtype) - varb = variables.Variable([3.0, 4.0], dtype=dtype) - gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) - gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=1.0, - burnin_max_learning_rate=3.0, - preconditioner_decay_rate=decay_rate) - sgd_op = optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1])) - optimizer2 = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=1.0, - burnin_max_learning_rate=10.0, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op2 = optimizer2.apply_gradients( - zip([gradsa, gradsb], [vara, varb])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) - - # Run 1 step of sgd - sgd_op.run() - sgd_op2.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3. * 0.1, 2.1 - 3. * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([1.1 - 0.1, 2.1 - 0.1], vara.eval()) - - self.assertAllCloseAccordingToType([3.0 - 3. * 0.01, 4.0 - 3. * 0.01], - var1.eval()) - self.assertAllCloseAccordingToType([3.0 - 0.01, 4.0 - 0.01], - varb.eval()) - self.assertNotEqual(optimizer.variable_scope, - optimizer2.variable_scope) - self.assertNotEqual(optimizer.variable_scope.name, - optimizer2.variable_scope.name) - self.assertAllCloseAccordingToType(1, optimizer._counter.eval()) - self.assertAllCloseAccordingToType(1, optimizer2._counter.eval()) - - def testTensorLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = constant_op.constant(3.0) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - sgd_op = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=lrate, - burnin=0, - preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - - def testTensorDecayLearningRate(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - lrate = variables.Variable(3.0) - lrate_decay_op = lrate.assign_add(-3.) - decay_rate = 0.5 - batch_size = 2 - total_num_examples = 10 - optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=lrate, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - # Update learning rate to 0 - lrate_decay_op.eval() - sgd_op.run() - # Validate params haven't changed - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - lrate_decay_op.eval() - - with self.assertRaises(errors.InvalidArgumentError): - sgd_op.run() - - def testGradWrtRef(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - opt = VariationalSGDOptimizer(1, 1, max_learning_rate=1.0) - values = [1.0, 3.0] - vars_ = [variables.Variable([v], dtype=dtype) for v in values] - grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) - variables.global_variables_initializer().run() - for grad, _ in grads_and_vars: - self.assertAllCloseAccordingToType([1.0], grad.eval()) - - def testWithGlobalStep(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - global_step = variables.Variable(0, trainable=False) - var0 = variables.Variable([1.1, 2.1], dtype=dtype) - var1 = variables.Variable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - decay_rate = 0.1 - batch_size = 2 - total_num_examples = 10 - sgd_optimizer = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=3.0, - burnin=0, - preconditioner_decay_rate=decay_rate) - sgd_op = sgd_optimizer.apply_gradients( - zip([grads0, grads1], [var0, var1]), global_step=global_step) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) - self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - - # Validate updated params and global_step - self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], - var0.eval()) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - var1.eval()) - self.assertAllCloseAccordingToType(1, global_step.eval()) - self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) - - def testSparseBasic(self): - for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): - var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) - var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) - grads0 = ops.IndexedSlices( - constant_op.constant([0.1], shape=[1, 1], dtype=dtype), - constant_op.constant([0]), constant_op.constant([2, 1])) - grads1 = ops.IndexedSlices( - constant_op.constant([0.01], shape=[1, 1], dtype=dtype), - constant_op.constant([1]), constant_op.constant([2, 1])) - decay_rate = 0.1 - batch_size = 2 - total_num_examples = 10 - sgd_op = VariationalSGDOptimizer( - batch_size, - total_num_examples, - max_learning_rate=3.0, - burnin=0, - preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) - variables.global_variables_initializer().run() - # Fetch params to validate initial values - self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) - self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) - # Run 1 step of sgd - sgd_op.run() - # Validate updated params - self.assertAllCloseAccordingToType([[1.1 - 3.0 * 0.1], [2.1]], - var0.eval()) - self.assertAllCloseAccordingToType( - [[3.0 - 3.0 * 0], [4.0 - 3.0 * 0.01]], var1.eval()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py deleted file mode 100644 index 8efd59d6516924bea538717d45bb4ae303583421..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py +++ /dev/null @@ -1,1105 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Csiszar f-Divergence and helpers. - -@@amari_alpha -@@arithmetic_geometric -@@chi_square -@@csiszar_vimco -@@dual_csiszar_function -@@jeffreys -@@jensen_shannon -@@kl_forward -@@kl_reverse -@@log1p_abs -@@modified_gan -@@monte_carlo_csiszar_f_divergence -@@pearson -@@squared_hellinger -@@symmetrized_csiszar_function -@@total_variation -@@triangular - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util - - -def amari_alpha(logu, alpha=1., self_normalized=False, name=None): - """The Amari-alpha Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Amari-alpha Csiszar-function is: - - ```none - f(u) = { -log(u) + (u - 1), alpha = 0 - { u log(u) - (u - 1), alpha = 1 - { [(u**alpha - 1) - alpha (u - 1)] / (alpha (alpha - 1)), otherwise - ``` - - When `self_normalized = False` the `(u - 1)` terms are omitted. - - Warning: when `alpha != 0` and/or `self_normalized = True` this function makes - non-log-space calculations and may therefore be numerically unstable for - `|logu| >> 0`. - - For more information, see: - A. Cichocki and S. Amari. "Families of Alpha-Beta-and GammaDivergences: - Flexible and Robust Measures of Similarities." Entropy, vol. 12, no. 6, pp. - 1532-1568, 2010. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - alpha: `float`-like Python scalar. (See Mathematical Details for meaning.) - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - amari_alpha_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - - Raises: - TypeError: if `alpha` is `None` or a `Tensor`. - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - with ops.name_scope(name, "amari_alpha", [logu]): - if alpha is None or contrib_framework.is_tensor(alpha): - raise TypeError("`alpha` cannot be `None` or `Tensor` type.") - if self_normalized is None or contrib_framework.is_tensor(self_normalized): - raise TypeError("`self_normalized` cannot be `None` or `Tensor` type.") - - logu = ops.convert_to_tensor(logu, name="logu") - - if alpha == 0.: - f = -logu - elif alpha == 1.: - f = math_ops.exp(logu) * logu - else: - f = math_ops.expm1(alpha * logu) / (alpha * (alpha - 1.)) - - if not self_normalized: - return f - - if alpha == 0.: - return f + math_ops.expm1(logu) - elif alpha == 1.: - return f - math_ops.expm1(logu) - else: - return f - math_ops.expm1(logu) / (alpha - 1.) - - -def kl_reverse(logu, self_normalized=False, name=None): - """The reverse Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-reverse Csiszar-function is: - - ```none - f(u) = -log(u) + (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[q, p] - ``` - - The KL is "reverse" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: when self_normalized = True` this function makes non-log-space - calculations and may therefore be numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_reverse_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_reverse", [logu]): - return amari_alpha(logu, alpha=0., self_normalized=self_normalized) - - -def kl_forward(logu, self_normalized=False, name=None): - """The forward Kullback-Leibler Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the KL-forward Csiszar-function is: - - ```none - f(u) = u log(u) - (u - 1) - ``` - - When `self_normalized = False` the `(u - 1)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, q] - ``` - - The KL is "forward" because in maximum likelihood we think of minimizing `q` - as in `KL[p, q]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - kl_forward_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - - Raises: - TypeError: if `self_normalized` is `None` or a `Tensor`. - """ - - with ops.name_scope(name, "kl_forward", [logu]): - return amari_alpha(logu, alpha=1., self_normalized=self_normalized) - - -def jensen_shannon(logu, self_normalized=False, name=None): - """The Jensen-Shannon Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True`, the Jensen-Shannon Csiszar-function is: - - ```none - f(u) = u log(u) - (1 + u) log(1 + u) + (u + 1) log(2) - ``` - - When `self_normalized = False` the `(u + 1) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[p, m] + KL[q, m] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Arithmetic-Geometric - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - For more information, see: - Lin, J. "Divergence measures based on the Shannon entropy." IEEE Trans. - Inf. Th., 37, 145-151, 1991. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jensen_shannon_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jensen_shannon", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - npdt = logu.dtype.as_numpy_dtype - y = nn_ops.softplus(logu) - if self_normalized: - y -= np.log(2).astype(npdt) - return math_ops.exp(logu) * logu - (1. + math_ops.exp(logu)) * y - - -def arithmetic_geometric(logu, self_normalized=False, name=None): - """The Arithmetic-Geometric Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the Arithmetic-Geometric Csiszar-function is: - - ```none - f(u) = (1 + u) log( (1 + u) / sqrt(u) ) - (1 + u) log(2) - ``` - - When `self_normalized = False` the `(1 + u) log(2)` term is omitted. - - Observe that as an f-Divergence, this Csiszar-function implies: - - ```none - D_f[p, q] = KL[m, p] + KL[m, q] - m(x) = 0.5 p(x) + 0.5 q(x) - ``` - - In a sense, this divergence is the "reverse" of the Jensen-Shannon - f-Divergence. - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - arithmetic_geometric_of_u: `float`-like `Tensor` of the - Csiszar-function evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "arithmetic_geometric", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - 0.5 * logu - if self_normalized: - y -= np.log(2.).astype(logu.dtype.as_numpy_dtype) - return (1. + math_ops.exp(logu)) * y - - -def total_variation(logu, name=None): - """The Total Variation Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Total-Variation Csiszar-function is: - - ```none - f(u) = 0.5 |u - 1| - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - total_variation_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "total_variation", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.abs(math_ops.expm1(logu)) - - -def pearson(logu, name=None): - """The Pearson Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Pearson Csiszar-function is: - - ```none - f(u) = (u - 1)**2 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - pearson_of_u: `float`-like `Tensor` of the Csiszar-function evaluated at - `u = exp(logu)`. - """ - - with ops.name_scope(name, "pearson", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.square(math_ops.expm1(logu)) - - -def squared_hellinger(logu, name=None): - """The Squared-Hellinger Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Squared-Hellinger Csiszar-function is: - - ```none - f(u) = (sqrt(u) - 1)**2 - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - squared_hellinger_of_u: `float`-like `Tensor` of the Csiszar-function - evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "squared_hellinger", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(0.5 * logu) - - -def triangular(logu, name=None): - """The Triangular Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Triangular Csiszar-function is: - - ```none - f(u) = (u - 1)**2 / (1 + u) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - triangular_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "triangular", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return pearson(logu) / (1. + math_ops.exp(logu)) - - -def t_power(logu, t, self_normalized=False, name=None): - """The T-Power Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the T-Power Csiszar-function is: - - ```none - f(u) = s [ u**t - 1 - t(u - 1) ] - s = { -1 0 < t < 1 - { +1 otherwise - ``` - - When `self_normalized = False` the `- t(u - 1)` term is omitted. - - This is similar to the `amari_alpha` Csiszar-function, with the associated - divergence being the same up to factors depending only on `t`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - t: `Tensor` of same `dtype` as `logu` and broadcastable shape. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - t_power_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - with ops.name_scope(name, "t_power", [logu, t]): - logu = ops.convert_to_tensor(logu, name="logu") - t = ops.convert_to_tensor(t, dtype=logu.dtype.base_dtype, name="t") - fu = math_ops.expm1(t * logu) - if self_normalized: - fu -= t * math_ops.expm1(logu) - fu *= array_ops.where(math_ops.logical_and(0. < t, t < 1.), - -array_ops.ones_like(t), - array_ops.ones_like(t)) - return fu - - -def log1p_abs(logu, name=None): - """The log1p-abs Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Log1p-Abs Csiszar-function is: - - ```none - f(u) = u**(sign(u-1)) - 1 - ``` - - This function is so-named because it was invented from the following recipe. - Choose a convex function g such that g(0)=0 and solve for f: - - ```none - log(1 + f(u)) = g(log(u)). - <=> - f(u) = exp(g(log(u))) - 1 - ``` - - That is, the graph is identically `g` when y-axis is `log1p`-domain and x-axis - is `log`-domain. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log1p_abs_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "log1p_abs", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(math_ops.abs(logu)) - - -def jeffreys(logu, name=None): - """The Jeffreys Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Jeffreys Csiszar-function is: - - ```none - f(u) = 0.5 ( u log(u) - log(u) ) - = 0.5 kl_forward + 0.5 kl_reverse - = symmetrized_csiszar_function(kl_reverse) - = symmetrized_csiszar_function(kl_forward) - ``` - - This Csiszar-function induces a symmetric f-Divergence, i.e., - `D_f[p, q] = D_f[q, p]`. - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - jeffreys_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "jeffreys", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * math_ops.expm1(logu) * logu - - -def chi_square(logu, name=None): - """The chi-Square Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Chi-square Csiszar-function is: - - ```none - f(u) = u**2 - 1 - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return math_ops.expm1(2. * logu) - - -def modified_gan(logu, self_normalized=False, name=None): - """The Modified-GAN Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - When `self_normalized = True` the modified-GAN (Generative/Adversarial - Network) Csiszar-function is: - - ```none - f(u) = log(1 + u) - log(u) + 0.5 (u - 1) - ``` - - When `self_normalized = False` the `0.5 (u - 1)` is omitted. - - The unmodified GAN Csiszar-function is identical to Jensen-Shannon (with - `self_normalized = False`). - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When - `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even - when `p, q` are unnormalized measures. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - chi_square_of_u: `float`-like `Tensor` of the Csiszar-function evaluated - at `u = exp(logu)`. - """ - - with ops.name_scope(name, "chi_square", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - y = nn_ops.softplus(logu) - logu - if self_normalized: - y += 0.5 * math_ops.expm1(logu) - return y - - -def dual_csiszar_function(logu, csiszar_function, name=None): - """Calculates the dual Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar-dual is defined as: - - ```none - f^*(u) = u f(1 / u) - ``` - - where `f` is some other Csiszar-function. - - For example, the dual of `kl_reverse` is `kl_forward`, i.e., - - ```none - f(u) = -log(u) - f^*(u) = u f(1 / u) = -u log(1 / u) = u log(u) - ``` - - The dual of the dual is the original function: - - ```none - f^**(u) = {u f(1/u)}^*(u) = u (1/u) f(1/(1/u)) = f(u) - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - dual_f_of_u: `float`-like `Tensor` of the result of calculating the dual of - `f` at `u = exp(logu)`. - """ - - with ops.name_scope(name, "dual_csiszar_function", [logu]): - return math_ops.exp(logu) * csiszar_function(-logu) - - -def symmetrized_csiszar_function(logu, csiszar_function, name=None): - """Symmetrizes a Csiszar-function in log-space. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The symmetrized Csiszar-function is defined as: - - ```none - f_g(u) = 0.5 g(u) + 0.5 u g (1 / u) - ``` - - where `g` is some other Csiszar-function. - - We say the function is "symmetrized" because: - - ```none - D_{f_g}[p, q] = D_{f_g}[q, p] - ``` - - for all `p << >> q` (i.e., `support(p) = support(q)`). - - There exists alternatives for symmetrizing a Csiszar-function. For example, - - ```none - f_g(u) = max(f(u), f^*(u)), - ``` - - where `f^*` is the dual Csiszar-function, also implies a symmetric - f-Divergence. - - Example: - - When either of the following functions are symmetrized, we obtain the - Jensen-Shannon Csiszar-function, i.e., - - ```none - g(u) = -log(u) - (1 + u) log((1 + u) / 2) + u - 1 - h(u) = log(4) + 2 u log(u / (1 + u)) - ``` - - implies, - - ```none - f_g(u) = f_h(u) = u log(u) - (1 + u) log((1 + u) / 2) - = jensen_shannon(log(u)). - ``` - - Warning: this function makes non-log-space calculations and may therefore be - numerically unstable for `|logu| >> 0`. - - Args: - logu: `float`-like `Tensor` representing `log(u)` from above. - csiszar_function: Python `callable` representing a Csiszar-function over - log-domain. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - symmetrized_g_of_u: `float`-like `Tensor` of the result of applying the - symmetrization of `g` evaluated at `u = exp(logu)`. - """ - - with ops.name_scope(name, "symmetrized_csiszar_function", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - return 0.5 * (csiszar_function(logu) - + dual_csiszar_function(logu, csiszar_function)) - - -def monte_carlo_csiszar_f_divergence( - f, - p_log_prob, - q, - num_draws, - use_reparametrization=None, - seed=None, - name=None): - """Monte-Carlo approximation of the Csiszar f-Divergence. - - A Csiszar-function is a member of, - - ```none - F = { f:R_+ to R : f convex }. - ``` - - The Csiszar f-Divergence for Csiszar-function f is given by: - - ```none - D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ] - ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ), - where x_j ~iid q(X) - ``` - - Tricks: Reparameterization and Score-Gradient - - When q is "reparameterized", i.e., a diffeomorphic transformation of a - parameterless distribution (e.g., - `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and - expectation, i.e., - `grad[Avg{ s_i : i=1...n }] = Avg{ grad[s_i] : i=1...n }` where `S_n=Avg{s_i}` - and `s_i = f(x_i), x_i ~iid q(X)`. - - However, if q is not reparameterized, TensorFlow's gradient will be incorrect - since the chain-rule stops at samples of unreparameterized distributions. In - this circumstance using the Score-Gradient trick results in an unbiased - gradient, i.e., - - ```none - grad[ E_q[f(X)] ] - = grad[ int dx q(x) f(x) ] - = int dx grad[ q(x) f(x) ] - = int dx [ q'(x) f(x) + q(x) f'(x) ] - = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ] - = int dx q(x) grad[ f(x) q(x) / stop_grad[q(x)] ] - = E_q[ grad[ f(x) q(x) / stop_grad[q(x)] ] ] - ``` - - Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is - usually preferable to set `use_reparametrization = True`. - - Example Application: - - The Csiszar f-Divergence is a useful framework for variational inference. - I.e., observe that, - - ```none - f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] ) - <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ] - := D_f[p(x, Z), q(Z | x)] - ``` - - The inequality follows from the fact that the "perspective" of `f`, i.e., - `(s, t) |-> t f(s / t))`, is convex in `(s, t)` when `s/t in domain(f)` and - `t` is a real. Since the above framework includes the popular Evidence Lower - BOund (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this framework - "Evidence Divergence Bound Optimization" (EDBO). - - Args: - f: Python `callable` representing a Csiszar-function in log-space, i.e., - takes `p_log_prob(q_samples) - q.log_prob(q_samples)`. - p_log_prob: Python `callable` taking (a batch of) samples from `q` and - returning the natural-log of the probability under distribution `p`. - (In variational inference `p` is the joint distribution.) - q: `tf.Distribution`-like instance; must implement: - `reparameterization_type`, `sample(n, seed)`, and `log_prob(x)`. - (In variational inference `q` is the approximate posterior distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - use_reparametrization: Python `bool`. When `None` (the default), - automatically set to: - `q.reparameterization_type == distribution.FULLY_REPARAMETERIZED`. - When `True` uses the standard Monte-Carlo average. When `False` uses the - score-gradient trick. (See above for details.) When `False`, consider - using `csiszar_vimco`. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - monte_carlo_csiszar_f_divergence: `float`-like `Tensor` Monte Carlo - approximation of the Csiszar f-Divergence. - - Raises: - ValueError: if `q` is not a reparameterized distribution and - `use_reparametrization = True`. A distribution `q` is said to be - "reparameterized" when its samples are generated by transforming the - samples of another distribution which does not depend on the - parameterization of `q`. This property ensures the gradient (with respect - to parameters) is valid. - TypeError: if `p_log_prob` is not a Python `callable`. - """ - with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]): - if use_reparametrization is None: - use_reparametrization = (q.reparameterization_type - == distribution.FULLY_REPARAMETERIZED) - elif (use_reparametrization and - q.reparameterization_type != distribution.FULLY_REPARAMETERIZED): - # TODO(jvdillon): Consider only raising an exception if the gradient is - # requested. - raise ValueError( - "Distribution `q` must be reparameterized, i.e., a diffeomorphic " - "transformation of a parameterless distribution. (Otherwise this " - "function has a biased gradient.)") - if not callable(p_log_prob): - raise TypeError("`p_log_prob` must be a Python `callable` function.") - return monte_carlo.expectation( - f=lambda q_samples: f(p_log_prob(q_samples) - q.log_prob(q_samples)), - samples=q.sample(num_draws, seed=seed), - log_prob=q.log_prob, # Only used if use_reparametrization=False. - use_reparametrization=use_reparametrization) - - -def csiszar_vimco(f, - p_log_prob, - q, - num_draws, - num_batch_draws=1, - seed=None, - name=None): - """Use VIMCO to lower the variance of gradient[csiszar_function(Avg(logu))]. - - This function generalizes "Variational Inference for Monte Carlo Objectives" - (VIMCO), i.e., https://arxiv.org/abs/1602.06725, to Csiszar f-Divergences. - - Note: if `q.reparameterization_type = distribution.FULLY_REPARAMETERIZED`, - consider using `monte_carlo_csiszar_f_divergence`. - - The VIMCO loss is: - - ```none - vimco = f(Avg{logu[i] : i=0,...,m-1}) - where, - logu[i] = log( p(x, h[i]) / q(h[i] | x) ) - h[i] iid~ q(H | x) - ``` - - Interestingly, the VIMCO gradient is not the naive gradient of `vimco`. - Rather, it is characterized by: - - ```none - grad[vimco] - variance_reducing_term - where, - variance_reducing_term = Sum{ grad[log q(h[i] | x)] * - (vimco - f(log Avg{h[j;i] : j=0,...,m-1})) - : i=0, ..., m-1 } - h[j;i] = { u[j] j!=i - { GeometricAverage{ u[k] : k!=i} j==i - ``` - - (We omitted `stop_gradient` for brevity. See implementation for more details.) - - The `Avg{h[j;i] : j}` term is a kind of "swap-out average" where the `i`-th - element has been replaced by the leave-`i`-out Geometric-average. - - This implementation prefers numerical precision over efficiency, i.e., - `O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`. - (The constant may be fairly large, perhaps around 12.) - - Args: - f: Python `callable` representing a Csiszar-function in log-space. - p_log_prob: Python `callable` representing the natural-log of the - probability under distribution `p`. (In variational inference `p` is the - joint distribution.) - q: `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and - `log_prob(x)`. (In variational inference `q` is the approximate posterior - distribution.) - num_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - num_batch_draws: Integer scalar number of draws used to approximate the - f-Divergence expectation. - seed: Python `int` seed for `q.sample`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - vimco: The Csiszar f-Divergence generalized VIMCO objective. - - Raises: - ValueError: if `num_draws < 2`. - """ - with ops.name_scope(name, "csiszar_vimco", [num_draws, num_batch_draws]): - if num_draws < 2: - raise ValueError("Must specify num_draws > 1.") - stop = array_ops.stop_gradient # For readability. - x = stop(q.sample(sample_shape=[num_draws, num_batch_draws], - seed=seed)) - logqx = q.log_prob(x) - logu = p_log_prob(x) - logqx - f_log_avg_u, f_log_sooavg_u = [f(r) for r in csiszar_vimco_helper(logu)] - dotprod = math_ops.reduce_sum( - logqx * stop(f_log_avg_u - f_log_sooavg_u), - axis=0) # Sum over iid samples. - # We now rewrite f_log_avg_u so that: - # `grad[f_log_avg_u] := grad[f_log_avg_u + dotprod]`. - # To achieve this, we use a trick that - # `f(x) - stop(f(x)) == zeros_like(f(x))` - # but its gradient is grad[f(x)]. - # Note that IEEE754 specifies that `x - x == 0.` and `x + 0. == x`, hence - # this trick loses no precision. For more discussion regarding the relevant - # portions of the IEEE754 standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - f_log_avg_u += dotprod - stop(dotprod) # Add zeros_like(dot_prod). - return math_ops.reduce_mean(f_log_avg_u, axis=0) # Avg over batches. - - -def csiszar_vimco_helper(logu, name=None): - """Helper to `csiszar_vimco`; computes `log_avg_u`, `log_sooavg_u`. - - `axis = 0` of `logu` is presumed to correspond to iid samples from `q`, i.e., - - ```none - logu[j] = log(u[j]) - u[j] = p(x, h[j]) / q(h[j] | x) - h[j] iid~ q(H | x) - ``` - - Args: - logu: Floating-type `Tensor` representing `log(p(x, h) / q(h | x))`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - log_avg_u: `logu.dtype` `Tensor` corresponding to the natural-log of the - average of `u`. The sum of the gradient of `log_avg_u` is `1`. - log_sooavg_u: `logu.dtype` `Tensor` characterized by the natural-log of the - average of `u`` except that the average swaps-out `u[i]` for the - leave-`i`-out Geometric-average. The mean of the gradient of - `log_sooavg_u` is `1`. Mathematically `log_sooavg_u` is, - ```none - log_sooavg_u[i] = log(Avg{h[j ; i] : j=0, ..., m-1}) - h[j ; i] = { u[j] j!=i - { GeometricAverage{u[k] : k != i} j==i - ``` - - """ - with ops.name_scope(name, "csiszar_vimco_helper", [logu]): - logu = ops.convert_to_tensor(logu, name="logu") - - n = logu.shape.with_rank_at_least(1)[0].value - if n is None: - n = array_ops.shape(logu)[0] - log_n = math_ops.log(math_ops.cast(n, dtype=logu.dtype)) - nm1 = math_ops.cast(n - 1, dtype=logu.dtype) - else: - log_n = np.log(n).astype(logu.dtype.as_numpy_dtype) - nm1 = np.asarray(n - 1, dtype=logu.dtype.as_numpy_dtype) - - # Throughout we reduce across axis=0 since this is presumed to be iid - # samples. - - log_max_u = math_ops.reduce_max(logu, axis=0) - log_sum_u_minus_log_max_u = math_ops.reduce_logsumexp( - logu - log_max_u, axis=0) - - # log_loosum_u[i] = - # = logsumexp(logu[j] : j != i) - # = log( exp(logsumexp(logu)) - exp(logu[i]) ) - # = log( exp(logsumexp(logu - logu[i])) exp(logu[i]) - exp(logu[i])) - # = logu[i] + log(exp(logsumexp(logu - logu[i])) - 1) - # = logu[i] + log(exp(logsumexp(logu) - logu[i]) - 1) - # = logu[i] + softplus_inverse(logsumexp(logu) - logu[i]) - d = log_sum_u_minus_log_max_u + (log_max_u - logu) - # We use `d != 0` rather than `d > 0.` because `d < 0.` should never - # happens; if it does we want to complain loudly (which `softplus_inverse` - # will). - d_ok = math_ops.not_equal(d, 0.) - safe_d = array_ops.where(d_ok, d, array_ops.ones_like(d)) - d_ok_result = logu + distribution_util.softplus_inverse(safe_d) - - inf = np.array(np.inf, dtype=logu.dtype.as_numpy_dtype) - - # When not(d_ok) and is_positive_and_largest then we manually compute the - # log_loosum_u. (We can efficiently do this for any one point but not all, - # hence we still need the above calculation.) This is good because when - # this condition is met, we cannot use the above calculation; its -inf. - # We now compute the log-leave-out-max-sum, replicate it to every - # point and make sure to select it only when we need to. - is_positive_and_largest = math_ops.logical_and( - logu > 0., - math_ops.equal(logu, log_max_u[array_ops.newaxis, ...])) - log_lomsum_u = math_ops.reduce_logsumexp( - array_ops.where(is_positive_and_largest, - array_ops.fill(array_ops.shape(logu), -inf), - logu), - axis=0, keep_dims=True) - log_lomsum_u = array_ops.tile( - log_lomsum_u, - multiples=1 + array_ops.pad([n-1], [[0, array_ops.rank(logu)-1]])) - - d_not_ok_result = array_ops.where( - is_positive_and_largest, - log_lomsum_u, - array_ops.fill(array_ops.shape(d), -inf)) - - log_loosum_u = array_ops.where(d_ok, d_ok_result, d_not_ok_result) - - # The swap-one-out-sum ("soosum") is n different sums, each of which - # replaces the i-th item with the i-th-left-out average, i.e., - # soo_sum_u[i] = [exp(logu) - exp(logu[i])] + exp(mean(logu[!=i])) - # = exp(log_loosum_u[i]) + exp(looavg_logu[i]) - looavg_logu = (math_ops.reduce_sum(logu, axis=0) - logu) / nm1 - log_soosum_u = math_ops.reduce_logsumexp( - array_ops.stack([log_loosum_u, looavg_logu]), - axis=0) - - log_avg_u = log_sum_u_minus_log_max_u + log_max_u - log_n - log_sooavg_u = log_soosum_u - log_n - - log_avg_u.set_shape(logu.shape.with_rank_at_least(1)[1:]) - log_sooavg_u.set_shape(logu.shape) - - return log_avg_u, log_sooavg_u diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py deleted file mode 100644 index d44fe6529a7ff0da0c6747e193fdb98a272a8da3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions for specifying custom gradients. - -@@custom_gradient - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - -__all__ = [ - "custom_gradient", -] - - -def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, - name=None): - """Enables specifying a custom gradient. - - This function works by clever application of `stop_gradient`. I.e., observe - that: - - ```none - h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) - ``` - - is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = - stop_gradient(g(x)).` - - In addition to scalar-domain/scalar-range functions, this function also - supports tensor-domain/scalar-range functions. However, in the latter case it - is necessary to reduce `x` to a scalar. This can be done by indicating the - `axis` over which `f` operates or by appropriately `reduce_sum`-ing `x`, prior - to calling this function. - - Partial Custom Gradient: - - Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy = - None`. This is because a `Tensor` cannot have only a portion of its gradient - stopped. To circumvent this issue, one must manually `stop_gradient` the - relevant portions of `f`, `g`. For example see the unit-test, - `test_works_correctly_fx_gx_manually_stopped`. - - Args: - fx: `Tensor`. Output of function evaluated at `x`. - gx: `Tensor`. Gradient of function evaluated at `x`. - x: `Tensor`. Point of evaluation for `f, g`. - axis: 1D `int` `Tensor` representing dimensions of `x` which are the domain - of `f`. If `()` (the default), `f` is assumed scalar-domain/scalar-range. - If `None` `f` is assumed to render one scalar given all of `x`. Otherwise - `f` is assumed to output one scalar for each of `axis` dimensions of `x`. - fx_gx_manually_stopped: Python `bool` indicating that `fx`, `gx` manually - have `stop_gradient` applied. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - fx: Floating-type `Tensor` equal to `f(x)` but which has gradient - `stop_gradient(g(x))`. - """ - with ops.name_scope(name, "custom_gradient", [fx, gx, x]): - fx = ops.convert_to_tensor(fx, name="fx") - # We don't want to bother eagerly computing `gx` since we may not even need - # it. - with ops.control_dependencies([fx]): - gx = ops.convert_to_tensor(gx, dtype=fx.dtype, name="gx") - gx = array_ops.identity(gx, name="gx") - # Proof of correctness: - # - # f(x) = x * stop[gx] + stop[fx - x * gx] - # = stop[fx] - # - # g(x) = grad[fx] - # = stop[gx] + grad[stop[fx - x * gx]] - # = stop[gx] + 0 - # - # Notice that when x is zero it still works: - # grad[x * stop(gx) + stop(fx - x * gx)] = 1 * stop[gx] + 0 = stop[gx] - # - # The proof is similar for the tensor-domain case, except that `x` is - # replaced by `reduce_sum(x)`. - sum_x = math_ops.reduce_sum(x, axis=axis, name="sum_x") - if not fx_gx_manually_stopped: - fx = array_ops.stop_gradient(fx) - gx = array_ops.stop_gradient(gx) - # IEEE754 ensures `(x-x)==0.` and that `0.*x==0.` so we make sure to write - # the code this way, rather than, e.g., - # `sum_x * stop(gx) + stop(fx - sum_x * gx)`. - # For more discussion regarding the relevant portions of the IEEE754 - # standard, see the StackOverflow question, - # "Is there a floating point value of x, for which x-x == 0 is false?" - # http://stackoverflow.com/q/2686644 - return (sum_x - array_ops.stop_gradient(sum_x)) * gx + fx diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py deleted file mode 100644 index 8cabf18903b5f15002470acdfb8fdd3ec31a7413..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Quasi Monte Carlo support: Halton sequence. - -@@sample -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - - -__all__ = [ - 'sample', -] - - -# The maximum dimension we support. This is limited by the number of primes -# in the _PRIMES array. -_MAX_DIMENSION = 1000 - - -def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): - r"""Returns a sample from the `m` dimensional Halton sequence. - - Warning: The sequence elements take values only between 0 and 1. Care must be - taken to appropriately transform the domain of a function if it differs from - the unit cube before evaluating integrals using Halton samples. It is also - important to remember that quasi-random numbers are not a replacement for - pseudo-random numbers in every context. Quasi random numbers are completely - deterministic and typically have significant negative autocorrelation (unless - randomized). - - Computes the members of the low discrepancy Halton sequence in dimension - `dim`. The d-dimensional sequence takes values in the unit hypercube in d - dimensions. Currently, only dimensions up to 1000 are supported. The prime - base for the `k`-th axes is the k-th prime starting from 2. For example, - if dim = 3, then the bases will be [2, 3, 5] respectively and the first - element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete - description of the Halton sequences see: - https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences - and their applications see: - https://en.wikipedia.org/wiki/Low-discrepancy_sequence. - - The user must supply either `num_samples` or `sample_indices` but not both. - The former is the number of samples to produce starting from the first - element. If `sample_indices` is given instead, the specified elements of - the sequence are generated. For example, sample_indices=tf.range(10) is - equivalent to specifying n=10. - - Example Use: - - ```python - bf = tf.contrib.bayesflow - - # Produce the first 1000 members of the Halton sequence in 3 dimensions. - num_samples = 1000 - dim = 3 - sample = bf.halton_sequence.sample(dim, num_samples=num_samples) - - # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional - # hypercube. - powers = tf.range(1.0, limit=dim + 1) - integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1)) - true_value = 1.0 / tf.reduce_prod(powers + 1.0) - with tf.Session() as session: - values = session.run((integral, true_value)) - - # Produces a relative absolute error of 1.7%. - print ("Estimated: %f, True Value: %f" % values) - - # Now skip the first 1000 samples and recompute the integral with the next - # thousand samples. The sample_indices argument can be used to do this. - - - sample_indices = tf.range(start=1000, limit=1000 + num_samples, - dtype=tf.int32) - sample_leaped = halton.sample(dim, sample_indices=sample_indices) - - integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers, - axis=-1)) - with tf.Session() as session: - values = session.run((integral_leaped, true_value)) - # Now produces a relative absolute error of 0.05%. - print ("Leaped Estimated: %f, True Value: %f" % values) - ``` - - Args: - dim: Positive Python `int` representing each sample's `event_size.` Must - not be greater than 1000. - num_samples: (Optional) positive Python `int`. The number of samples to - generate. Either this parameter or sample_indices must be specified but - not both. If this parameter is None, then the behaviour is determined by - the `sample_indices`. - sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements - of the sequence to compute specified by their position in the sequence. - The entries index into the Halton sequence starting with 0 and hence, - must be whole numbers. For example, sample_indices=[0, 5, 6] will produce - the first, sixth and seventh elements of the sequence. If this parameter - is None, then the `num_samples` parameter must be specified which gives - the number of desired samples starting from the first sample. - dtype: (Optional) The dtype of the sample. One of `float32` or `float64`. - Default is `float32`. - name: (Optional) Python `str` describing ops managed by this function. If - not supplied the name of this function is used. - - Returns: - halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype - and `shape` `[num_samples, dim]` if `num_samples` was specified or shape - `[s, dim]` where s is the size of `sample_indices` if `sample_indices` - were specified. - - Raises: - ValueError: if both `sample_indices` and `num_samples` were specified or - if dimension `dim` is less than 1 or greater than 1000. - """ - if dim < 1 or dim > _MAX_DIMENSION: - raise ValueError( - 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION, - dim)) - if (num_samples is None) == (sample_indices is None): - raise ValueError('Either `num_samples` or `sample_indices` must be' - ' specified but not both.') - - dtype = dtype or dtypes.float32 - if not dtype.is_floating: - raise ValueError('dtype must be of `float`-type') - - with ops.name_scope(name, 'sample', values=[sample_indices]): - # Here and in the following, the shape layout is as follows: - # [sample dimension, event dimension, coefficient dimension]. - # The coefficient dimension is an intermediate axes which will hold the - # weights of the starting integer when expressed in the (prime) base for - # an event dimension. - indices = _get_indices(num_samples, sample_indices, dtype) - radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) - - max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices), - radixes) - - max_size = math_ops.reduce_max(max_sizes_by_axes) - - # The powers of the radixes that we will need. Note that there is a bit - # of an excess here. Suppose we need the place value coefficients of 7 - # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits - # for base 3. However, we can only create rectangular tensors so we - # store both expansions in a [2, 3] tensor. This leads to the problem that - # we might end up attempting to raise large numbers to large powers. For - # example, base 2 expansion of 1024 has 10 digits. If we were in 10 - # dimensions, then the 10th prime (29) we will end up computing 29^10 even - # though we don't need it. We avoid this by setting the exponents for each - # axes to 0 beyond the maximum value needed for that dimension. - exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1]) - weight_mask = exponents_by_axes > max_sizes_by_axes - capped_exponents = array_ops.where( - weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes) - weights = radixes ** capped_exponents - coeffs = math_ops.floor_div(indices, weights) - coeffs *= 1 - math_ops.cast(weight_mask, dtype) - coeffs = (coeffs % radixes) / radixes - return math_ops.reduce_sum(coeffs / weights, axis=-1) - - -def _get_indices(n, sample_indices, dtype, name=None): - """Generates starting points for the Halton sequence procedure. - - The k'th element of the sequence is generated starting from a positive integer - which must be distinct for each `k`. It is conventional to choose the starting - point as `k` itself (or `k+1` if k is zero based). This function generates - the starting integers for the required elements and reshapes the result for - later use. - - Args: - n: Positive `int`. The number of samples to generate. If this - parameter is supplied, then `sample_indices` should be None. - sample_indices: `Tensor` of dtype int32 and rank 1. The entries - index into the Halton sequence starting with 0 and hence, must be whole - numbers. For example, sample_indices=[0, 5, 6] will produce the first, - sixth and seventh elements of the sequence. If this parameter is not None - then `n` must be None. - dtype: The dtype of the sample. One of `float32` or `float64`. - Default is `float32`. - name: Python `str` name which describes ops created by this function. - - Returns: - indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`. - """ - with ops.name_scope(name, 'get_indices', [n, sample_indices]): - if sample_indices is None: - sample_indices = math_ops.range(n, dtype=dtype) - else: - sample_indices = math_ops.cast(sample_indices, dtype) - - # Shift the indices so they are 1 based. - indices = sample_indices + 1 - - # Reshape to make space for the event dimension and the place value - # coefficients. - return array_ops.reshape(indices, [-1, 1, 1]) - - -def _base_expansion_size(num, bases): - """Computes the number of terms in the place value expansion. - - Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of - `num` in base b (ak <> 0). This function computes and returns `k` for each - base `b` specified in `bases`. - - This can be inferred from the base `b` logarithm of `num` as follows: - $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ - - Args: - num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to - compute the base expansion size of. - bases: `Tensor` of the same dtype as num. The bases to compute the size - against. - - Returns: - Tensor of same dtype and shape as `bases` containing the size of num when - written in that base. - """ - return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1 - - -def _primes_less_than(n): - # Based on - # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 - """Returns sorted array of primes such that `2 <= prime < n`.""" - small_primes = np.array((2, 3, 5)) - if n <= 6: - return small_primes[small_primes < n] - sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool) - sieve[0] = False - m = int(n ** 0.5) // 3 + 1 - for i in range(m): - if not sieve[i]: - continue - k = 3 * i + 1 | 1 - sieve[k ** 2 // 3::2 * k] = False - sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False - return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] - -_PRIMES = _primes_less_than(7919+1) - -assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py deleted file mode 100644 index 7fd5652c5c3e085b23c05baef6e3a42b7a42e08f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member -from tensorflow.python.util import all_util - -_allowed_symbols = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", -] - -all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py deleted file mode 100644 index f724910c59315867a42a56fab3deb36f5d3adb7a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ /dev/null @@ -1,1185 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. - -@@sample_chain -@@sample_annealed_importance_chain -@@kernel -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import util as distributions_util - -__all__ = [ - "sample_chain", - "sample_annealed_importance_chain", - "kernel", -] - - -KernelResults = collections.namedtuple( - "KernelResults", - [ - "acceptance_probs", - "current_grads_target_log_prob", # "Current result" means "accepted". - "current_target_log_prob", # "Current result" means "accepted". - "energy_change", - "is_accepted", - "proposed_grads_target_log_prob", - "proposed_state", - "proposed_target_log_prob", - "random_positive", - ]) - - -def _make_dummy_kernel_results( - dummy_state, - dummy_target_log_prob, - dummy_grads_target_log_prob): - return KernelResults( - acceptance_probs=dummy_target_log_prob, - current_grads_target_log_prob=dummy_grads_target_log_prob, - current_target_log_prob=dummy_target_log_prob, - energy_change=dummy_target_log_prob, - is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), - proposed_grads_target_log_prob=dummy_grads_target_log_prob, - proposed_state=dummy_state, - proposed_target_log_prob=dummy_target_log_prob, - random_positive=dummy_target_log_prob, - ) - - -def sample_chain( - num_results, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - num_burnin_steps=0, - num_steps_between_results=0, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm - that takes a series of gradient-informed steps to produce a Metropolis - proposal. This function samples from an HMC Markov chain at `current_state` - and whose stationary distribution has log-unnormalized-density - `target_log_prob_fn()`. - - This function samples from multiple chains in parallel. It assumes that the - the leftmost dimensions of (each) `current_state` (part) index an independent - chain. The function `target_log_prob_fn()` sums log-probabilities across - event dimensions (i.e., current state (part) rightmost dimensions). Each - element of the output of `target_log_prob_fn()` represents the (possibly - unnormalized) log-probability of the joint distribution over (all) the current - state (parts). - - The `current_state` can be represented as a single `Tensor` or a `list` of - `Tensors` which collectively represent the current state. When specifying a - `list`, one must also specify a list of `step_size`s. - - Note: `target_log_prob_fn` is called exactly twice. - - Only one out of every `num_steps_between_samples + 1` steps is included in the - returned results. This "thinning" comes at a cost of reduced statistical - power, while reducing memory requirements and autocorrelation. For more - discussion see [1]. - - [1]: "Statistically efficient thinning of a Markov chain sampler." - Art B. Owen. April 2017. - http://statweb.stanford.edu/~owen/reports/bestthinning.pdf - - #### Examples: - - ##### Sample from a diagonal-variance Gaussian. - - ```python - tfd = tf.contrib.distributions - - def make_likelihood(true_variances): - return tfd.MultivariateNormalDiag( - scale_diag=tf.sqrt(true_variances)) - - dims = 10 - dtype = np.float32 - true_variances = tf.linspace(dtype(1), dtype(3), dims) - likelihood = make_likelihood(true_variances) - - states, kernel_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=likelihood.log_prob, - current_state=tf.zeros(dims), - step_size=0.5, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(states, axis=0) - sample_var = tf.reduce_mean( - tf.squared_difference(states, sample_mean), - axis=0) - ``` - - ##### Sampling from factor-analysis posteriors with known factors. - - I.e., - - ```none - for i=1..n: - w[i] ~ Normal(0, eye(d)) # prior - x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood - ``` - - where `F` denotes factors. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, factors): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) - - # Setup data. - num_weights = 10 - num_factors = 4 - num_chains = 100 - dtype = np.float32 - - prior = make_prior(num_weights, dtype) - weights = prior.sample(num_chains) - factors = np.random.randn(num_factors, num_weights).astype(dtype) - x = make_likelihood(weights, factors).sample(num_chains) - - def target_log_prob(w): - # Target joint is: `f(w) = p(w, x | factors)`. - return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) - - # Get `num_results` samples from `num_chains` independent chains. - chains_states, kernels_results = hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target_log_prob, - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2, - num_burnin_steps=500) - - # Compute sample stats. - sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) - sample_var = tf.reduce_mean( - tf.squared_difference(chains_states, sample_mean), - axis=[0, 1]) - ``` - - Args: - num_results: Integer number of Markov chain draws. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - num_burnin_steps: Integer number of chain steps to take before starting to - collect results. - Default value: 0 (i.e., no burn-in). - num_steps_between_results: Integer number of chain steps between collecting - a result. Only one out of every `num_steps_between_samples + 1` steps is - included in the returned results. This "thinning" comes at a cost of - reduced statistical power, while reducing memory requirements and - autocorrelation. For more discussion see [1]. - Default value: 0 (i.e., no subsampling). - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob` at the `current_state` and wrt - the `current_state`. Must have same shape as `current_state`. The only - reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_chain"). - - Returns: - accepted_states: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state` but with a prepended `num_results`-size dimension. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - with ops.name_scope( - name, "hmc_sample_chain", - [num_results, current_state, step_size, num_leapfrog_steps, - num_burnin_steps, num_steps_between_results, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob, - ] = _prepare_args( - target_log_prob_fn, - current_state, - step_size, - current_target_log_prob, - current_grads_target_log_prob) - num_results = ops.convert_to_tensor( - num_results, - dtype=dtypes.int32, - name="num_results") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - num_burnin_steps = ops.convert_to_tensor( - num_burnin_steps, - dtype=dtypes.int32, - name="num_burnin_steps") - num_steps_between_results = ops.convert_to_tensor( - num_steps_between_results, - dtype=dtypes.int32, - name="num_steps_between_results") - - def _run_chain(num_steps, current_state, kernel_results): - """Runs the chain(s) for `num_steps`.""" - def _loop_body(iter_, current_state, kernel_results): - return [iter_ + 1] + list(kernel( - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), - current_state, - kernel_results, - ], - ) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - return control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - def _scan_body(args_list, iter_): - """Closure which implements `tf.scan` body.""" - current_state, kernel_results = args_list - return _run_chain( - 1 + array_ops.where(math_ops.equal(iter_, 0), - num_burnin_steps, - num_steps_between_results), - current_state, - kernel_results) - - scan_kwargs = dict( - fn=_scan_body, - elems=math_ops.range(num_results), # iter_: used to choose burnin. - initializer=[ - current_state, - _make_dummy_kernel_results( - current_state, - current_target_log_prob, - current_grads_target_log_prob), - ]) - if seed is not None: - scan_kwargs["parallel_iterations"] = 1 - return functional_ops.scan(**scan_kwargs) - - -def sample_annealed_importance_chain( - proposal_log_prob_fn, - num_steps, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - name=None): - """Runs annealed importance sampling (AIS) to estimate normalizing constants. - - This function uses Hamiltonian Monte Carlo to sample from a series of - distributions that slowly interpolates between an initial "proposal" - distribution: - - `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - - and the target distribution: - - `exp(target_log_prob_fn(x) - target_log_normalizer)`, - - accumulating importance weights along the way. The product of these - importance weights gives an unbiased estimate of the ratio of the - normalizing constants of the initial distribution and the target - distribution: - - `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. - - Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three - times (although this may be reduced to two times, in the future). - - #### Examples: - - ##### Estimate the normalizing constant of a log-gamma distribution. - - ```python - tfd = tf.contrib.distributions - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 20 - dtype = np.float32 - - proposal = tfd.MultivatiateNormalDiag( - loc=tf.zeros([dims], dtype=dtype)) - - target = tfd.TransformedDistribution( - distribution=tfd.Gamma(concentration=dtype(2), - rate=dtype(3)), - bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), - event_shape=[dims]) - - chains_state, ais_weights, kernels_results = ( - hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal.log_prob, - num_steps=1000, - target_log_prob_fn=target.log_prob, - step_size=0.2, - current_state=proposal.sample(num_chains), - num_leapfrog_steps=2)) - - log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) - ``` - - ##### Estimate marginal likelihood of a Bayesian regression model. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, x): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, x, axes=[[0], [-1]])) - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 10 - dtype = np.float32 - - # Make training data. - x = np.random.randn(num_chains, dims).astype(dtype) - true_weights = np.random.randn(dims).astype(dtype) - y = np.dot(x, true_weights) + np.random.randn(num_chains) - - # Setup model. - prior = make_prior(dims, dtype) - def target_log_prob_fn(weights): - return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) - - proposal = tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - weight_samples, ais_weights, kernel_results = ( - hmc.sample_annealed_importance_chain( - num_steps=1000, - proposal_log_prob_fn=proposal.log_prob, - target_log_prob_fn=target_log_prob_fn - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2)) - log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - ``` - - Args: - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - num_steps: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). - - Returns: - accepted_state: `Tensor` or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at the final iteration. Has same shape as - input `current_state`. - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(current_state)`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - def make_convex_combined_log_prob_fn(iter_): - def _fn(*args): - p = proposal_log_prob_fn(*args) - t = target_log_prob_fn(*args) - dtype = p.dtype.base_dtype - beta = (math_ops.cast(iter_ + 1, dtype) - / math_ops.cast(num_steps, dtype)) - return (1. - beta) * p + beta * t - return _fn - - with ops.name_scope( - name, "hmc_sample_annealed_importance_chain", - [num_steps, current_state, step_size, num_leapfrog_steps, seed]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_log_prob, - current_grads_log_prob, - ] = _prepare_args( - make_convex_combined_log_prob_fn(iter_=0), - current_state, - step_size, - description="convex_combined_log_prob") - num_steps = ops.convert_to_tensor( - num_steps, - dtype=dtypes.int32, - name="num_steps") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - def _loop_body(iter_, ais_weights, current_state, kernel_results): - """Closure which implements `tf.while_loop` body.""" - current_state_parts = (list(current_state) - if _is_list_like(current_state) - else [current_state]) - # TODO(b/72994218): Consider refactoring things to avoid this unecessary - # call. - ais_weights += ((target_log_prob_fn(*current_state_parts) - - proposal_log_prob_fn(*current_state_parts)) - / math_ops.cast(num_steps, ais_weights.dtype)) - return [iter_ + 1, ais_weights] + list(kernel( - make_convex_combined_log_prob_fn(iter_), - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - array_ops.zeros_like(current_log_prob), # ais_weights - current_state, - _make_dummy_kernel_results(current_state, - current_log_prob, - current_grads_log_prob), - ]) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - - [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". - - return [current_state, ais_weights, kernel_results] - - -def kernel(target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Runs one iteration of Hamiltonian Monte Carlo. - - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function applies one step of HMC to - randomly update the variable `x`. - - This function can update multiple chains in parallel. It assumes that all - leftmost dimensions of `current_state` index independent chain states (and are - therefore updated independently). The output of `target_log_prob_fn()` should - sum log-probabilities across all event dimensions. Slices along the rightmost - dimensions may have different target distributions; for example, - `current_state[0, :]` could have a different target distribution from - `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of - independent chains is `tf.size(target_log_prob_fn(*current_state))`.) - - #### Examples: - - ##### Simple chain with warm-up. - - ```python - tfd = tf.contrib.distributions - - # Tuning acceptance rates: - dtype = np.float32 - target_accept_rate = 0.631 - num_warmup_iter = 500 - num_chain_iter = 500 - - x = tf.get_variable(name="x", initializer=dtype(1)) - step_size = tf.get_variable(name="step_size", initializer=dtype(1)) - - target = tfd.Normal(loc=dtype(0), scale=dtype(1)) - - new_x, other_results = hmc.kernel( - target_log_prob_fn=target.log_prob, - current_state=x, - step_size=step_size, - num_leapfrog_steps=3)[:4] - - x_update = x.assign(new_x) - - step_size_update = step_size.assign_add( - step_size * tf.where( - other_results.acceptance_probs > target_accept_rate, - 0.01, -0.01)) - - warmup = tf.group([x_update, step_size_update]) - - tf.global_variables_initializer().run() - - sess.graph.finalize() # No more graph building. - - # Warm up the sampler and adapt the step size - for _ in xrange(num_warmup_iter): - sess.run(warmup) - - # Collect samples without adapting step size - samples = np.zeros([num_chain_iter]) - for i in xrange(num_chain_iter): - _, x_, target_log_prob_, grad_ = sess.run([ - x_update, - x, - other_results.target_log_prob, - other_results.grads_target_log_prob]) - samples[i] = x_ - - print(samples.mean(), samples.std()) - ``` - - ##### Sample from more complicated posterior. - - I.e., - - ```none - W ~ MVN(loc=0, scale=sigma * eye(dims)) - for i=1...num_samples: - X[i] ~ MVN(loc=0, scale=eye(dims)) - eps[i] ~ Normal(loc=0, scale=1) - Y[i] = X[i].T * W + eps[i] - ``` - - ```python - tfd = tf.contrib.distributions - - def make_training_data(num_samples, dims, sigma): - dt = np.asarray(sigma).dtype - zeros = tf.zeros(dims, dtype=dt) - x = tfd.MultivariateNormalDiag( - loc=zeros).sample(num_samples, seed=1) - w = tfd.MultivariateNormalDiag( - loc=zeros, - scale_identity_multiplier=sigma).sample(seed=2) - noise = tfd.Normal( - loc=dt(0), - scale=dt(1)).sample(num_samples, seed=3) - y = tf.tensordot(x, w, axes=[[1], [0]]) + noise - return y, x, w - - def make_prior(sigma, dims): - # p(w | sigma) - return tfd.MultivariateNormalDiag( - loc=tf.zeros([dims], dtype=sigma.dtype), - scale_identity_multiplier=sigma) - - def make_likelihood(x, w): - # p(y | x, w) - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(x, w, axes=[[1], [0]])) - - # Setup assumptions. - dtype = np.float32 - num_samples = 150 - dims = 10 - num_iters = int(5e3) - - true_sigma = dtype(0.5) - y, x, true_weights = make_training_data(num_samples, dims, true_sigma) - - # Estimate of `log(true_sigma)`. - log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0)) - sigma = tf.exp(log_sigma) - - # State of the Markov chain. - weights = tf.get_variable( - name="weights", - initializer=np.random.randn(dims).astype(dtype)) - - prior = make_prior(sigma, dims) - - def joint_log_prob_fn(w): - # f(w) = log p(w, y | x) - return prior.log_prob(w) + make_likelihood(x, w).log_prob(y) - - weights_update = weights.assign( - hmc.kernel(target_log_prob_fn=joint_log_prob, - current_state=weights, - step_size=0.1, - num_leapfrog_steps=5)[0]) - - with tf.control_dependencies([weights_update]): - loss = -prior.log_prob(weights) - - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma]) - - sess.graph.finalize() # No more graph building. - - tf.global_variables_initializer().run() - - sigma_history = np.zeros(num_iters, dtype) - weights_history = np.zeros([num_iters, dims], dtype) - - for i in xrange(num_iters): - _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights]) - weights_history[i, :] = weights_ - sigma_history[i] = sigma_ - - true_weights_ = sess.run(true_weights) - - # Should converge to something close to true_sigma. - plt.plot(sigma_history); - plt.ylabel("sigma"); - plt.xlabel("iteration"); - ``` - - Args: - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn` at the `current_state`. The only reason to - specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `current_target_log_prob` at the `current_state` - and wrt the `current_state`. Must have same shape as `current_state`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_kernel"). - - Returns: - accepted_state: Tensor or Python list of `Tensor`s representing the state(s) - of the Markov chain(s) at each result step. Has same shape as - `current_state`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - - Raises: - ValueError: if there isn't one `step_size` or a list with same length as - `current_state`. - """ - with ops.name_scope( - name, "hmc_kernel", - [current_state, step_size, num_leapfrog_steps, seed, - current_target_log_prob, current_grads_target_log_prob]): - with ops.name_scope("initialize"): - [current_state_parts, step_sizes, current_target_log_prob, - current_grads_target_log_prob] = _prepare_args( - target_log_prob_fn, current_state, step_size, - current_target_log_prob, current_grads_target_log_prob, - maybe_expand=True) - independent_chain_ndims = distributions_util.prefer_static_rank( - current_target_log_prob) - current_momentums = [] - for s in current_state_parts: - current_momentums.append(random_ops.random_normal( - shape=array_ops.shape(s), - dtype=s.dtype.base_dtype, - seed=seed)) - seed = distributions_util.gen_new_seed( - seed, salt="hmc_kernel_momentums") - - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] = _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob, - current_grads_target_log_prob) - - energy_change = _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims) - - # u < exp(min(-energy, 0)), where u~Uniform[0,1) - # ==> -log(u) >= max(e, 0) - # ==> -log(u) >= e - # (Perhaps surprisingly, we don't have a better way to obtain a random - # uniform from positive reals, i.e., `tf.random_uniform(minval=0, - # maxval=np.inf)` won't work.) - random_uniform = random_ops.random_uniform( - shape=array_ops.shape(energy_change), - dtype=energy_change.dtype, - seed=seed) - random_positive = -math_ops.log(random_uniform) - is_accepted = random_positive >= energy_change - - accepted_target_log_prob = array_ops.where(is_accepted, - proposed_target_log_prob, - current_target_log_prob) - - accepted_state_parts = [_choose(is_accepted, - proposed_state_part, - current_state_part, - independent_chain_ndims) - for current_state_part, proposed_state_part - in zip(current_state_parts, proposed_state_parts)] - - accepted_grads_target_log_prob = [ - _choose(is_accepted, - proposed_grad, - grad, - independent_chain_ndims) - for proposed_grad, grad - in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] - - maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] - return [ - maybe_flatten(accepted_state_parts), - KernelResults( - acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)), - current_grads_target_log_prob=accepted_grads_target_log_prob, - current_target_log_prob=accepted_target_log_prob, - energy_change=energy_change, - is_accepted=is_accepted, - proposed_grads_target_log_prob=proposed_grads_target_log_prob, - proposed_state=maybe_flatten(proposed_state_parts), - proposed_target_log_prob=proposed_target_log_prob, - random_positive=random_positive, - ), - ] - - -def _leapfrog_integrator(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - num_leapfrog_steps, - current_target_log_prob=None, - current_grads_target_log_prob=None, - name=None): - """Applies `num_leapfrog_steps` of the leapfrog integrator. - - Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. - - #### Examples: - - ##### Simple quadratic potential. - - ```python - tfd = tf.contrib.distributions - - dims = 10 - num_iter = int(1e3) - dtype = np.float32 - - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - - [ - new_momentums, - new_positions, - ] = hmc._leapfrog_integrator( - current_momentums=[momentum], - target_log_prob_fn=tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)).log_prob, - current_state_parts=[position], - step_sizes=0.1, - num_leapfrog_steps=3)[:2] - - sess.graph.finalize() # No more graph building. - - momentum_ = np.random.randn(dims).astype(dtype) - position_ = np.random.randn(dims).astype(dtype) - - positions = np.zeros([num_iter, dims], dtype) - for i in xrange(num_iter): - position_, momentum_ = sess.run( - [new_momentums[0], new_position[0]], - feed_dict={position: position_, momentum: momentum_}) - positions[i] = position_ - - plt.plot(positions[:, 0]); # Sinusoidal. - ``` - - Args: - current_momentums: Tensor containing the value(s) of the momentum - variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like - `*current_state_parts` and returns its (possibly unnormalized) log-density - under the target distribution. - current_state_parts: Python `list` of `Tensor`s representing the current - state(s) of the Markov chain(s). The first `independent_chain_ndims` of - the `Tensor`(s) index different chains. - step_sizes: Python `list` of `Tensor`s representing the step size for the - leapfrog integrator. Must broadcast with the shape of - `current_state_parts`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. When - possible, it's often helpful to match per-variable step sizes to the - standard deviations of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - current_target_log_prob: (Optional) `Tensor` representing the value of - `target_log_prob_fn(*current_state_parts)`. The only reason to specify - this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - current_grads_target_log_prob: (Optional) Python list of `Tensor`s - representing gradient of `target_log_prob_fn(*current_state_parts`) wrt - `current_state_parts`. Must have same shape as `current_state_parts`. The - only reason to specify this argument is to reduce TF graph size. - Default value: `None` (i.e., compute as needed). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_leapfrog_integrator"). - - Returns: - proposed_momentums: Updated value of the momentum. - proposed_state_parts: Tensor or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at each result step. Has same shape as - input `current_state_parts`. - proposed_target_log_prob: `Tensor` representing the value of - `target_log_prob_fn` at `accepted_state`. - proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt - `accepted_state`. - - Raises: - ValueError: if `len(momentums) != len(state_parts)`. - ValueError: if `len(state_parts) != len(step_sizes)`. - ValueError: if `len(state_parts) != len(grads_target_log_prob)`. - TypeError: if `not target_log_prob.dtype.is_floating`. - """ - def _loop_body(step, - current_momentums, - current_state_parts, - ignore_current_target_log_prob, # pylint: disable=unused-argument - current_grads_target_log_prob): - return [step + 1] + list(_leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob)) - - with ops.name_scope( - name, "hmc_leapfrog_integrator", - [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps, - current_target_log_prob, current_grads_target_log_prob]): - if len(current_momentums) != len(current_state_parts): - raise ValueError("`momentums` must be in one-to-one correspondence " - "with `state_parts`") - num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps, - name="num_leapfrog_steps") - current_target_log_prob, current_grads_target_log_prob = ( - _maybe_call_fn_and_grads( - target_log_prob_fn, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob)) - return control_flow_ops.while_loop( - cond=lambda iter_, *args: iter_ < num_leapfrog_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - current_momentums, - current_state_parts, - current_target_log_prob, - current_grads_target_log_prob, - ], - back_prop=False)[1:] # Lop-off "iter_". - - -def _leapfrog_step(current_momentums, - target_log_prob_fn, - current_state_parts, - step_sizes, - current_grads_target_log_prob, - name=None): - """Applies one step of the leapfrog integrator.""" - with ops.name_scope( - name, "_leapfrog_step", - [current_momentums, current_state_parts, step_sizes, - current_grads_target_log_prob]): - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(current_momentums, - step_sizes, - current_grads_target_log_prob)] - proposed_state_parts = [x + ss * m for x, ss, m - in zip(current_state_parts, - step_sizes, - proposed_momentums)] - proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) - if not proposed_target_log_prob.dtype.is_floating: - raise TypeError("`target_log_prob_fn` must produce a `Tensor` " - "with `float` `dtype`.") - proposed_grads_target_log_prob = gradients_ops.gradients( - proposed_target_log_prob, proposed_state_parts) - if any(g is None for g in proposed_grads_target_log_prob): - raise ValueError( - "Encountered `None` gradient. Does your target `target_log_prob_fn` " - "access all `tf.Variable`s via `tf.get_variable`?\n" - " current_state_parts: {}\n" - " proposed_state_parts: {}\n" - " proposed_grads_target_log_prob: {}".format( - current_state_parts, - proposed_state_parts, - proposed_grads_target_log_prob)) - proposed_momentums = [m + 0.5 * ss * g for m, ss, g - in zip(proposed_momentums, - step_sizes, - proposed_grads_target_log_prob)] - return [ - proposed_momentums, - proposed_state_parts, - proposed_target_log_prob, - proposed_grads_target_log_prob, - ] - - -def _compute_energy_change(current_target_log_prob, - current_momentums, - proposed_target_log_prob, - proposed_momentums, - independent_chain_ndims, - name=None): - """Helper to `kernel` which computes the energy change.""" - with ops.name_scope( - name, "compute_energy_change", - ([current_target_log_prob, proposed_target_log_prob, - independent_chain_ndims] + - current_momentums + proposed_momentums)): - # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy - # since they're a mouthful and lets us inline more. - lk0, lk1 = [], [] - for current_momentum, proposed_momentum in zip(current_momentums, - proposed_momentums): - axis = math_ops.range(independent_chain_ndims, - array_ops.rank(current_momentum)) - lk0.append(_log_sum_sq(current_momentum, axis)) - lk1.append(_log_sum_sq(proposed_momentum, axis)) - - lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1), - axis=-1) - lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), - axis=-1) - lp0 = -current_target_log_prob # log_potential - lp1 = -proposed_target_log_prob # proposed_log_potential - x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], - axis=-1) - - # The sum is NaN if any element is NaN or we see both +Inf and -Inf. - # Thus we will replace such rows with infinite energy change which implies - # rejection. Recall that float-comparisons with NaN are always False. - is_sum_determinate = ( - math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) & - math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1)) - is_sum_determinate = array_ops.tile( - is_sum_determinate[..., array_ops.newaxis], - multiples=array_ops.concat([ - array_ops.ones(array_ops.rank(is_sum_determinate), - dtype=dtypes.int32), - [4], - ], axis=0)) - x = array_ops.where(is_sum_determinate, - x, - array_ops.fill(array_ops.shape(x), - value=x.dtype.as_numpy_dtype(np.inf))) - - return math_ops.reduce_sum(x, axis=-1) - - -def _choose(is_accepted, - accepted, - rejected, - independent_chain_ndims, - name=None): - """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where.""" - def _expand_is_accepted_like(x): - with ops.name_scope("_choose"): - expand_shape = array_ops.concat([ - array_ops.shape(is_accepted), - array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)], - dtype=dtypes.int32), - ], axis=0) - multiples = array_ops.concat([ - array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32), - array_ops.shape(x)[independent_chain_ndims:], - ], axis=0) - m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape), - multiples) - m.set_shape(x.shape) - return m - with ops.name_scope(name, "_choose", values=[ - is_accepted, accepted, rejected, independent_chain_ndims]): - return array_ops.where(_expand_is_accepted_like(accepted), - accepted, - rejected) - - -def _maybe_call_fn_and_grads(fn, - fn_arg_list, - fn_result=None, - grads_fn_result=None, - description="target_log_prob"): - """Helper which computes `fn_result` and `grads` if needed.""" - fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list) - else [fn_arg_list]) - if fn_result is None: - fn_result = fn(*fn_arg_list) - if not fn_result.dtype.is_floating: - raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format( - description)) - if grads_fn_result is None: - grads_fn_result = gradients_ops.gradients( - fn_result, fn_arg_list) - if len(fn_arg_list) != len(grads_fn_result): - raise ValueError("`{}` must be in one-to-one correspondence with " - "`grads_{}`".format(*[description]*2)) - if any(g is None for g in grads_fn_result): - raise ValueError("Encountered `None` gradient.") - return fn_result, grads_fn_result - - -def _prepare_args(target_log_prob_fn, state, step_size, - target_log_prob=None, grads_target_log_prob=None, - maybe_expand=False, description="target_log_prob"): - """Helper which processes input args to meet list-like assumptions.""" - state_parts = list(state) if _is_list_like(state) else [state] - state_parts = [ops.convert_to_tensor(s, name="state") - for s in state_parts] - target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads( - target_log_prob_fn, - state_parts, - target_log_prob, - grads_target_log_prob, - description) - step_sizes = list(step_size) if _is_list_like(step_size) else [step_size] - step_sizes = [ - ops.convert_to_tensor( - s, name="step_size", dtype=target_log_prob.dtype) - for s in step_sizes] - if len(step_sizes) == 1: - step_sizes *= len(state_parts) - if len(state_parts) != len(step_sizes): - raise ValueError("There should be exactly one `step_size` or it should " - "have same length as `current_state`.") - maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0] - return [ - maybe_flatten(state_parts), - maybe_flatten(step_sizes), - target_log_prob, - grads_target_log_prob, - ] - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _log_sum_sq(x, axis=None): - """Computes log(sum(x**2)).""" - return math_ops.reduce_logsumexp(2. * math_ops.log(math_ops.abs(x)), axis) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py deleted file mode 100644 index a742b7c1aa593d6c08bf9d8d597c99c9fc4e7aed..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Probabilistic neural layers. - -See ${python/contrib.bayesflow.layers}. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import * -from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational import * -from tensorflow.contrib.bayesflow.python.ops.layers_util import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'Convolution1DReparameterization', - 'Convolution2DReparameterization', - 'Convolution3DReparameterization', - 'Convolution1DFlipout', - 'Convolution2DFlipout', - 'Convolution3DFlipout', - 'Conv1DReparameterization', - 'Conv2DReparameterization', - 'Conv3DReparameterization', - 'Conv1DFlipout', - 'Conv2DFlipout', - 'Conv3DFlipout', - 'convolution1d_reparameterization', - 'convolution2d_reparameterization', - 'convolution3d_reparameterization', - 'convolution1d_flipout', - 'convolution2d_flipout', - 'convolution3d_flipout', - 'conv1d_reparameterization', - 'conv2d_reparameterization', - 'conv3d_reparameterization', - 'conv1d_flipout', - 'conv2d_flipout', - 'conv3d_flipout', - 'DenseReparameterization', - 'DenseLocalReparameterization', - 'DenseFlipout', - 'dense_reparameterization', - 'dense_local_reparameterization', - 'dense_flipout', - 'default_loc_scale_fn', - 'default_mean_field_normal_fn', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py deleted file mode 100644 index 7723cfb442712626ff415f1412e3362f2392ce9f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py +++ /dev/null @@ -1,2943 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Convolutional variational layer classes and their functional aliases. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.layers import utils -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util - - -class _ConvVariational(layers_lib.Layer): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - """ - - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(_ConvVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self.rank = rank - self.filters = filters - self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size") - self.strides = utils.normalize_tuple(strides, rank, "strides") - self.padding = utils.normalize_padding(padding) - self.data_format = utils.normalize_data_format(data_format) - self.dilation_rate = utils.normalize_tuple( - dilation_rate, rank, "dilation_rate") - self.activation = activation - self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2) - self.kernel_posterior_fn = kernel_posterior_fn - self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn - self.kernel_prior_fn = kernel_prior_fn - self.kernel_divergence_fn = kernel_divergence_fn - self.bias_posterior_fn = bias_posterior_fn - self.bias_posterior_tensor_fn = bias_posterior_tensor_fn - self.bias_prior_fn = bias_prior_fn - self.bias_divergence_fn = bias_divergence_fn - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - if self.data_format == "channels_first": - channel_axis = 1 - else: - channel_axis = -1 - if input_shape[channel_axis].value is None: - raise ValueError("The channel dimension of the inputs " - "should be defined. Found `None`.") - input_dim = input_shape[channel_axis].value - kernel_shape = self.kernel_size + (input_dim, self.filters) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel_posterior = self.kernel_posterior_fn( - dtype, kernel_shape, "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel_prior_fn is None: - self.kernel_prior = None - else: - self.kernel_prior = self.kernel_prior_fn( - dtype, kernel_shape, "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias_posterior_fn is None: - self.bias_posterior = None - else: - self.bias_posterior = self.bias_posterior_fn( - dtype, (self.filters,), "bias_posterior", - self.trainable, self.add_variable) - - if self.bias_prior_fn is None: - self.bias_prior = None - else: - self.bias_prior = self.bias_prior_fn( - dtype, (self.filters,), "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2, - axes={channel_axis: input_dim}) - self._convolution_op = nn_ops.Convolution( - input_shape, - filter_shape=tensor_shape.TensorShape(kernel_shape), - dilation_rate=self.dilation_rate, - strides=self.strides, - padding=self.padding.upper(), - data_format=utils.convert_data_format(self.data_format, - self.rank + 2)) - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) - if not self._built_kernel_divergence: - kernel_posterior = self.kernel_posterior - kernel_prior = self.kernel_prior - if isinstance(self.kernel_posterior, independent_lib.Independent): - kernel_posterior = kernel_posterior.distribution - if isinstance(self.kernel_prior, independent_lib.Independent): - kernel_prior = kernel_prior.distribution - self._apply_divergence(self.kernel_divergence_fn, - kernel_posterior, - kernel_prior, - self.kernel_posterior_tensor, - name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - bias_posterior = self.bias_posterior - bias_prior = self.bias_prior - if isinstance(self.bias_posterior, independent_lib.Independent): - bias_posterior = bias_posterior.distribution - if isinstance(self.bias_prior, independent_lib.Independent): - bias_prior = bias_prior.distribution - self._apply_divergence(self.bias_divergence_fn, - bias_posterior, - bias_prior, - self.bias_posterior_tensor, - name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_bias(self, inputs): - if self.bias_posterior is None: - self.bias_posterior_tensor = None - return inputs - self.bias_posterior_tensor = self.bias_posterior_tensor_fn( - self.bias_posterior) - outputs = inputs - if self.data_format == "channels_first": - if self.rank == 1: - # nn.bias_add does not accept a 1D input tensor. - bias = array_ops.reshape(self.bias_posterior_tensor, - (1, self.filters, 1)) - outputs += bias - if self.rank == 2: - outputs = nn.bias_add(outputs, - self.bias_posterior_tensor, - data_format="NCHW") - if self.rank == 3: - # As of Mar 2017, direct addition is significantly slower than - # bias_add when computing gradients. To use bias_add, we collapse Z - # and Y into a single dimension to obtain a 4D input tensor. - outputs_shape = outputs.shape.as_list() - outputs_4d = array_ops.reshape(outputs, - [outputs_shape[0], outputs_shape[1], - outputs_shape[2] * outputs_shape[3], - outputs_shape[4]]) - outputs_4d = nn.bias_add(outputs_4d, - self.bias_posterior_tensor, - data_format="NCHW") - outputs = array_ops.reshape(outputs_4d, outputs_shape) - else: - outputs = nn.bias_add(outputs, - self.bias_posterior_tensor, - data_format="NHWC") - return outputs - - def _apply_divergence(self, divergence_fn, posterior, prior, - posterior_tensor, name): - if (divergence_fn is None or - posterior is None or - prior is None): - divergence = None - return - divergence = standard_ops.identity( - divergence_fn( - posterior, prior, posterior_tensor), - name=name) - self.add_loss(divergence) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if self.data_format == "channels_last": - space = input_shape[1:-1] - new_space = [] - for i in range(len(space)): - new_dim = utils.conv_output_length( - space[i], - self.kernel_size[i], - padding=self.padding, - stride=self.strides[i], - dilation=self.dilation_rate[i]) - new_space.append(new_dim) - return tensor_shape.TensorShape([input_shape[0]] + new_space + - [self.filters]) - else: - space = input_shape[2:] - new_space = [] - for i in range(len(space)): - new_dim = utils.conv_output_length( - space[i], - self.kernel_size[i], - padding=self.padding, - stride=self.strides[i], - dilation=self.dilation_rate[i]) - new_space.append(new_dim) - return tensor_shape.TensorShape([input_shape[0], self.filters] + - new_space) - - -class _ConvReparameterization(_ConvVariational): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(_ConvReparameterization, self).__init__( - rank=rank, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - def _apply_variational_kernel(self, inputs): - self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( - self.kernel_posterior) - self.kernel_posterior_affine = None - self.kernel_posterior_affine_tensor = None - outputs = self._convolution_op(inputs, self.kernel_posterior_tensor) - return outputs - - -class Conv1DReparameterization(_ConvReparameterization): - """1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.Conv1DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(Conv1DReparameterization, self).__init__( - rank=1, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -def conv1d_reparameterization( - inputs, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.conv1d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = Conv1DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv2DReparameterization(_ConvReparameterization): - """2D convolution layer (e.g. spatial convolution over images). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.Conv2DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(Conv2DReparameterization, self).__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -def conv2d_reparameterization( - inputs, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.conv2d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = Conv2DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv3DReparameterization(_ConvReparameterization): - """3D convolution layer (e.g. spatial convolution over volumes). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.Conv3DReparameterization(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(Conv3DReparameterization, self).__init__( - rank=3, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - - -def conv3d_reparameterization( - inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the reparameterization - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.conv3d_reparameterization(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = Conv3DReparameterization( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class _ConvFlipout(_ConvVariational): - """Abstract nD convolution layer (private, used as implementation base). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - rank: Python integer, dimensionality of convolution. - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - rank, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(_ConvFlipout, self).__init__( - rank=rank, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, **kwargs) - self.seed = seed - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`{}` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format( - type(self).__name__, self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), - scale=self.kernel_posterior.distribution.scale) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - - outputs = self._convolution_op( - inputs, self.kernel_posterior.distribution.loc) - - input_shape = array_ops.shape(inputs) - output_shape = array_ops.shape(outputs) - batch_shape = array_ops.expand_dims(input_shape[0], 0) - channels = input_shape[-1] - - sign_input = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(channels, 0)], 0), - dtype=inputs.dtype, - seed=self.seed) - sign_output = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(self.filters, 0)], 0), - dtype=inputs.dtype, - seed=distribution_util.gen_new_seed( - self.seed, salt="conv_flipout")) - for _ in range(self.rank): - sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) - sign_output = array_ops.expand_dims(sign_output, 1) - - sign_input = array_ops.tile( # tile for element-wise op broadcasting - sign_input, - [1] + [input_shape[i + 1] for i in range(self.rank)] + [1]) - sign_output = array_ops.tile( - sign_output, - [1] + [output_shape[i + 1] for i in range(self.rank)] + [1]) - - perturbed_inputs = self._convolution_op( - inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output - - outputs += perturbed_inputs - return outputs - - -class Conv1DFlipout(_ConvFlipout): - """1D convolution layer (e.g. temporal convolution) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.Conv1DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(Conv1DFlipout, self).__init__( - rank=1, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -def conv1d_flipout( - inputs, - filters, - kernel_size, - strides=1, - padding="valid", - data_format="channels_last", - dilation_rate=1, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Functional interface for 1D convolution layer (e.g. temporal convolution). - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 128, 1]) - net = tfp.layers.conv1d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.reshape(net, [-1, 128 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = Conv1DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv2DFlipout(_ConvFlipout): - """2D convolution layer (e.g. spatial convolution over images) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.Conv2DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(Conv2DFlipout, self).__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -def conv2d_flipout( - inputs, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Functional interface for the 2D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 32, 32, 3]) - net = tfp.layers.conv2d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 8 * 8 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = Conv2DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class Conv3DFlipout(_ConvFlipout): - """3D convolution layer (e.g. spatial convolution over volumes) with Flipout. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - - Properties: - filters: Python integer, dimensionality of the output space. - kernel_size: Size of the convolution window. - strides: Stride length of convolution. - padding: Python string describing padding approach. - data_format: Python string describing input data's dimensions. - dilation_rate: Dilation rate for an atrous convolution. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.Conv3DFlipout(64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu)(net) - net = tf.layers.MaxPooling2D(pool_size=2, - strides=2, - padding="SAME")(net) - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(Conv3DFlipout, self).__init__( - rank=3, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, **kwargs) - - -def conv3d_flipout( - inputs, - filters, - kernel_size, - strides=(1, 1, 1), - padding="valid", - data_format="channels_last", - dilation_rate=(1, 1, 1), - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Functional interface for the 3D convolution layer. - - This layer creates a convolution kernel that is convolved - (actually cross-correlated) with the layer input to produce a tensor of - outputs. It may also include a bias addition and activation function - on the outputs. It assumes the `kernel` and/or `bias` are drawn from - distributions. - - By default, the layer implements a stochastic forward pass via - sampling from the kernel and bias posteriors, - ```none - outputs = f(inputs; kernel, bias), kernel, bias ~ posterior - ``` - where f denotes the layer's calculation. It uses the Flipout - estimator [1], which performs a Monte Carlo approximation of the - distribution integrating over the `kernel` and `bias`. Flipout uses - roughly twice as many floating point operations as the - reparameterization estimator but has the advantage of significantly - lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Arguments: - inputs: Tensor input. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - reuse: Boolean, whether to reuse the weights of a previous layer - by the same name. - - Returns: - Output tensor. - - Raises: - ValueError: if eager execution is enabled. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tf.reshape(features, [-1, 256, 32, 32, 3]) - net = tfp.layers.conv3d_flipout(net, - filters=64, - kernel_size=5, - padding="SAME", - activation=tf.nn.relu) - net = tf.layers.max_pooling2d(net, - pool_size=2, - strides=2, - padding="SAME") - net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = Conv3DFlipout( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -# Aliases - -Convolution1DReparameterization = Conv1DReparameterization -Convolution2DReparameterization = Conv2DReparameterization -Convolution3DReparameterization = Conv3DReparameterization -convolution1d_reparameterization = conv1d_reparameterization -convolution2d_reparameterization = conv2d_reparameterization -convolution3d_reparameterization = conv3d_reparameterization -Convolution1DFlipout = Conv1DFlipout -Convolution2DFlipout = Conv2DFlipout -Convolution3DFlipout = Conv3DFlipout -convolution1d_flipout = conv1d_flipout -convolution2d_flipout = conv2d_flipout -convolution3d_flipout = conv3d_flipout diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py deleted file mode 100644 index 591a8e553de0c194786c7ee8693665f762711b2d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py +++ /dev/null @@ -1,1176 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Dense Bayesian layer using KL-divergence based variational inference. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import layers_util -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util - - -class _DenseVariational(layers_lib.Layer): - """Abstract densely-connected class (private, used as implementation base). - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(_DenseVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self.units = units - self.activation = activation - self.input_spec = layers_lib.InputSpec(min_ndim=2) - self.kernel_posterior_fn = kernel_posterior_fn - self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn - self.kernel_prior_fn = kernel_prior_fn - self.kernel_divergence_fn = kernel_divergence_fn - self.bias_posterior_fn = bias_posterior_fn - self.bias_posterior_tensor_fn = bias_posterior_tensor_fn - self.bias_prior_fn = bias_prior_fn - self.bias_divergence_fn = bias_divergence_fn - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - in_size = input_shape.with_rank_at_least(2)[-1].value - if in_size is None: - raise ValueError("The last dimension of the inputs to `Dense` " - "should be defined. Found `None`.") - self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel_posterior = self.kernel_posterior_fn( - dtype, [in_size, self.units], "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel_prior_fn is None: - self.kernel_prior = None - else: - self.kernel_prior = self.kernel_prior_fn( - dtype, [in_size, self.units], "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias_posterior_fn is None: - self.bias_posterior = None - else: - self.bias_posterior = self.bias_posterior_fn( - dtype, [self.units], "bias_posterior", - self.trainable, self.add_variable) - - if self.bias_prior_fn is None: - self.bias_prior = None - else: - self.bias_prior = self.bias_prior_fn( - dtype, [self.units], "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) # pylint: disable=not-callable - if not self._built_kernel_divergence: - kernel_posterior = self.kernel_posterior - kernel_prior = self.kernel_prior - if isinstance(self.kernel_posterior, independent_lib.Independent): - kernel_posterior = kernel_posterior.distribution - if isinstance(self.kernel_prior, independent_lib.Independent): - kernel_prior = kernel_prior.distribution - self._apply_divergence(self.kernel_divergence_fn, - kernel_posterior, - kernel_prior, - self.kernel_posterior_tensor, - name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - bias_posterior = self.bias_posterior - bias_prior = self.bias_prior - if isinstance(self.bias_posterior, independent_lib.Independent): - bias_posterior = bias_posterior.distribution - if isinstance(self.bias_prior, independent_lib.Independent): - bias_prior = bias_prior.distribution - self._apply_divergence(self.bias_divergence_fn, - bias_posterior, - bias_prior, - self.bias_posterior_tensor, - name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_bias(self, inputs): - if self.bias_posterior is None: - self.bias_posterior_tensor = None - return inputs - self.bias_posterior_tensor = self.bias_posterior_tensor_fn( - self.bias_posterior) - return nn.bias_add(inputs, self.bias_posterior_tensor) - - def _apply_divergence(self, divergence_fn, posterior, prior, - posterior_tensor, name): - if (divergence_fn is None or - posterior is None or - prior is None): - divergence = None - return - divergence = standard_ops.identity( - divergence_fn( - posterior, prior, posterior_tensor), - name=name) - self.add_loss(divergence) - - def _matmul(self, inputs, kernel): - if inputs.shape.ndims <= 2: - return standard_ops.matmul(inputs, kernel) - # To handle broadcasting, we must use `tensordot`. - return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) - if input_shape[-1].value is None: - raise ValueError( - "The innermost dimension of input_shape must be defined, " - "but saw: {}".format(input_shape)) - return input_shape[:-1].concatenate(self.units) - - -class DenseReparameterization(_DenseVariational): - """Densely-connected layer class with reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the reparameterization estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseReparameterization( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(DenseReparameterization, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - - def _apply_variational_kernel(self, inputs): - self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( - self.kernel_posterior) - self.kernel_posterior_affine = None - self.kernel_posterior_affine_tensor = None - return self._matmul(inputs, self.kernel_posterior_tensor) - - -def dense_reparameterization( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Densely-connected layer with reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the reparameterization estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_reparameterization( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Auto-Encoding Variational Bayes." - Diederik P. Kingma, Max Welling. - International Conference on Learning Representations, 2014. - """ - layer = DenseReparameterization( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class DenseLocalReparameterization(_DenseVariational): - """Densely-connected layer class with local reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the local reparameterization estimator [1], which performs a - Monte Carlo approximation of the distribution on the hidden units - induced by the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseLocalReparameterization( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseLocalReparameterization(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses local reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Variational Dropout and the Local Reparameterization Trick." - Diederik P. Kingma, Tim Salimans, Max Welling. - Neural Information Processing Systems, 2015. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(DenseLocalReparameterization, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`DenseLocalReparameterization` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format(self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=self._matmul(inputs, self.kernel_posterior.distribution.loc), - scale=standard_ops.sqrt(self._matmul( - standard_ops.square(inputs), - standard_ops.square(self.kernel_posterior.distribution.scale)))) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - return self.kernel_posterior_affine_tensor - - -def dense_local_reparameterization( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Densely-connected layer with local reparameterization estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the local reparameterization estimator [1], which performs a - Monte Carlo approximation of the distribution on the hidden units - induced by the `kernel` and `bias`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_local_reparameterization( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_local_reparameterization(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses local reparameterization gradients to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Variational Dropout and the Local Reparameterization Trick." - Diederik P. Kingma, Tim Salimans, Max Welling. - Neural Information Processing Systems, 2015. - """ - layer = DenseLocalReparameterization( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class DenseFlipout(_DenseVariational): - """Densely-connected layer class with Flipout estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the Flipout estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. Flipout uses roughly twice as many floating point operations - as the reparameterization estimator but has the advantage of - significantly lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_posterior_fn: `callable` returning posterior. - kernel_posterior_tensor_fn: `callable` operating on posterior. - kernel_prior_fn: `callable` returning prior. - kernel_divergence_fn: `callable` returning divergence. - bias_posterior_fn: `callable` returning posterior. - bias_posterior_tensor_fn: `callable` operating on posterior. - bias_prior_fn: `callable` returning prior. - bias_divergence_fn: `callable` returning divergence. - seed: Python integer, used to create random seeds. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.DenseFlipout( - 512, activation=tf.nn.relu)(features) - logits = tfp.layers.DenseFlipout(10)(net) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - **kwargs): - super(DenseFlipout, self).__init__( - units=units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - **kwargs) - self.seed = seed - - def _apply_variational_kernel(self, inputs): - if (not isinstance(self.kernel_posterior, independent_lib.Independent) or - not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): - raise TypeError( - "`DenseFlipout` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Independent(tf.distributions.Normal)` " - "(saw: \"{}\").".format(self.kernel_posterior.name)) - self.kernel_posterior_affine = normal_lib.Normal( - loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), - scale=self.kernel_posterior.distribution.scale) - self.kernel_posterior_affine_tensor = ( - self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) - self.kernel_posterior_tensor = None - - input_shape = array_ops.shape(inputs) - batch_shape = input_shape[:-1] - - sign_input = layers_util.random_sign( - input_shape, - dtype=inputs.dtype, - seed=self.seed) - sign_output = layers_util.random_sign( - array_ops.concat([batch_shape, - array_ops.expand_dims(self.units, 0)], 0), - dtype=inputs.dtype, - seed=distribution_util.gen_new_seed( - self.seed, salt="dense_flipout")) - perturbed_inputs = self._matmul( - inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output - - outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc) - outputs += perturbed_inputs - return outputs - - -def dense_flipout( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=layers_util.default_mean_field_normal_fn( - is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - seed=None, - name=None, - reuse=None): - """Densely-connected layer with Flipout estimator. - - This layer implements the Bayesian variational inference analogue to - a dense layer by assuming the `kernel` and/or the `bias` are drawn - from distributions. By default, the layer implements a stochastic - forward pass via sampling from the kernel and bias posteriors, - - ```none - kernel, bias ~ posterior - outputs = activation(matmul(inputs, kernel) + bias) - ``` - - It uses the Flipout estimator [1], which performs a Monte Carlo - approximation of the distribution integrating over the `kernel` and - `bias`. Flipout uses roughly twice as many floating point operations - as the reparameterization estimator but has the advantage of - significantly lower variance. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - distributions. - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - - #### Examples - - We illustrate a Bayesian neural network with [variational inference]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), - assuming a dataset of `features` and `labels`. - - ```python - tfp = tf.contrib.bayesflow - - net = tfp.layers.dense_flipout( - features, 512, activation=tf.nn.relu) - logits = tfp.layers.dense_flipout(net, 10) - neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( - labels=labels, logits=logits) - kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) - loss = neg_log_likelihood + kl - train_op = tf.train.AdamOptimizer().minimize(loss) - ``` - - It uses the Flipout gradient estimator to minimize the - Kullback-Leibler divergence up to a constant, also known as the - negative Evidence Lower Bound. It consists of the sum of two terms: - the expected negative log-likelihood, which we approximate via - Monte Carlo; and the KL divergence, which is added via regularizer - terms which are arguments to the layer. - - [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on - Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb - """ - layer = DenseFlipout( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - seed=seed, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py deleted file mode 100644 index 8c1fb203f7328e8260e49b4326d813fbe133613e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_util.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for probabilistic layers. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib -from tensorflow.contrib.distributions.python.ops import independent as independent_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops.distributions import normal as normal_lib - - -def default_loc_scale_fn( - is_singular=False, - loc_initializer=init_ops.random_normal_initializer(stddev=0.1), - untransformed_scale_initializer=init_ops.random_normal_initializer( - mean=-3., stddev=0.1), - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. - - This function produces a closure which produces `loc`, `scale` using - `tf.get_variable`. The closure accepts the following arguments: - - dtype: Type of parameter's event. - shape: Python `list`-like representing the parameter's event shape. - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` indicating if `scale is None`. Default: `False`. - loc_initializer: Initializer function for the `loc` parameters. - The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. Default value: `tf.random_normal_initializer(mean=-3., - stddev=0.1)`. This implies the softplus transformed result has mean - approximately `0.05` and std. deviation approximately `0.005`. - loc_regularizer: Regularizer function for the `loc` parameters. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. The default (`None`) is to use the `tf.get_variable` default. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. The default - (`None`) is to use the `tf.get_variable` default. - - Returns: - default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` - parameters from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates `loc`, `scale` parameters.""" - loc = add_variable_fn( - name=name + "_loc", - shape=shape, - initializer=loc_initializer, - regularizer=loc_regularizer, - constraint=loc_constraint, - dtype=dtype, - trainable=trainable) - if is_singular: - return loc, None - untransformed_scale = add_variable_fn( - name=name + "_untransformed_scale", - shape=shape, - initializer=untransformed_scale_initializer, - regularizer=untransformed_scale_regularizer, - constraint=untransformed_scale_constraint, - dtype=dtype, - trainable=trainable) - scale = (np.finfo(dtype.as_numpy_dtype).eps + - nn_ops.softplus(untransformed_scale)) - return loc, scale - return _fn - - -def default_mean_field_normal_fn( - is_singular=False, - loc_initializer=None, - untransformed_scale_initializer=None, - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Creates a function to build Normal distributions with trainable params. - - This function produces a closure which produces `tf.distributions.Normal` - parameterized by a loc` and `scale` each created using `tf.get_variable`. The - produced closure accepts the following arguments: - - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` if `True`, forces the special case limit of - `scale->0`, i.e., a `Deterministic` distribution. - loc_initializer: Initializer function for the `loc` parameters. - If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - loc_regularizer: Regularizer function for the `loc` parameters. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. - - Returns: - make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` - using from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - loc_scale_fn_ = default_loc_scale_fn( - is_singular, - loc_initializer, - untransformed_scale_initializer, - loc_regularizer, - untransformed_scale_regularizer, - loc_constraint, - untransformed_scale_constraint) - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates multivariate `Deterministic` or `Normal` distribution.""" - loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) - if scale is None: - dist = deterministic_lib.Deterministic(loc=loc) - else: - dist = normal_lib.Normal(loc=loc, scale=scale) - reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] - return independent_lib.Independent( - dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) - return _fn - - -def random_sign(shape, dtype=dtypes.float32, seed=None): - """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" - random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, - dtype=dtypes.int32, - seed=seed) - return math_ops.cast(2 * random_bernoulli - 1, dtype) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py deleted file mode 100644 index 0424b6952bc89ce7fe5b00b0135c9a5fe1faa8cf..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for Markov Chain Monte Carlo (MCMC) sampling. - -@@effective_sample_size -@@potential_scale_reduction -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops import sample_stats -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - -__all__ = [ - "effective_sample_size", - "potential_scale_reduction", -] - - -def effective_sample_size(states, - filter_threshold=0., - filter_beyond_lag=None, - name=None): - """Estimate a lower bound on effective sample size for each independent chain. - - Roughly speaking, "effective sample size" (ESS) is the size of an iid sample - with the same variance as `state`. - - More precisely, given a stationary sequence of possibly correlated random - variables `X_1, X_2,...,X_N`, each identically distributed ESS is the number - such that - - ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.``` - - If the sequence is uncorrelated, `ESS = N`. In general, one should expect - `ESS <= N`, with more highly correlated sequences having smaller `ESS`. - - #### Example of using ESS to estimate standard error. - - ``` - tfd = tf.contrib.distributions - tfb = tf.contrib.bayesflow - - target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) - - # Get 1000 states from one chain. - states = tfb.hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target.log_prob, - current_state=tf.constant([0., 0.]), - step_size=0.05, - num_leapfrog_steps=20, - num_burnin_steps=200) - states.shape - ==> (1000, 2) - - ess = effective_sample_size(states) - ==> Shape (2,) Tensor - - mean, variance = tf.nn.moments(states, axis=0) - standard_error = tf.sqrt(variance / ess) - ``` - - Some math shows that, with `R_k` the auto-correlation sequence, - `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have - - ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]``` - - This function estimates the above by first estimating the auto-correlation. - Since `R_k` must be estimated using only `N - k` samples, it becomes - progressively noisier for larger `k`. For this reason, the summation over - `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many - MCMC methods generate chains where `R_k > 0`, a reasonable critera is to - truncate at the first index where the estimated auto-correlation becomes - negative. - - The arguments `filter_beyond_lag`, `filter_threshold` are filters intended to - remove noisy tail terms from `R_k`. They combine in an "OR" manner meaning - terms are removed if they were to be filtered under the `filter_beyond_lag` OR - `filter_threshold` criteria. - - Args: - states: `Tensor` or list of `Tensor` objects. Dimension zero should index - identically distributed states. - filter_threshold: `Tensor` or list of `Tensor` objects. - Must broadcast with `state`. The auto-correlation sequence is truncated - after the first appearance of a term less than `filter_threshold`. - Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`, - setting to any number less than `-1` has the same effect. - filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be - `int`-like and scalar valued. The auto-correlation sequence is truncated - to this length. Setting to `None` means we do not filter based on number - of lags. - name: `String` name to prepend to created ops. - - Returns: - ess: `Tensor` or list of `Tensor` objects. The effective sample size of - each component of `states`. Shape will be `states.shape[1:]`. - - Raises: - ValueError: If `states` and `filter_threshold` or `states` and - `filter_beyond_lag` are both lists with different lengths. - """ - states_was_list = _is_list_like(states) - - # Convert all args to lists. - if not states_was_list: - states = [states] - - filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag, - "filter_beyond_lag") - filter_threshold = _broadcast_maybelist_arg(states, filter_threshold, - "filter_threshold") - - # Process items, one at a time. - with ops.name_scope(name, "effective_sample_size"): - ess_list = [ - _effective_sample_size_single_state(s, ml, mlt) - for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold) - ] - - if states_was_list: - return ess_list - return ess_list[0] - - -def _effective_sample_size_single_state(states, filter_beyond_lag, - filter_threshold): - """ESS computation for one single Tensor argument.""" - - with ops.name_scope( - "effective_sample_size_single_state", - values=[states, filter_beyond_lag, filter_threshold]): - - states = ops.convert_to_tensor(states, name="states") - dt = states.dtype - - # filter_beyond_lag == None ==> auto_corr is the full sequence. - auto_corr = sample_stats.auto_correlation( - states, axis=0, max_lags=filter_beyond_lag) - if filter_threshold is not None: - filter_threshold = ops.convert_to_tensor( - filter_threshold, dtype=dt, name="filter_threshold") - # Get a binary mask to zero out values of auto_corr below the threshold. - # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, - # mask[i, ...] = 0, otherwise. - # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] - # Building step by step, - # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. - # Step 1: mask = [False, False, True, False] - mask = auto_corr < filter_threshold - # Step 2: mask = [0, 0, 1, 1] - mask = math_ops.cast(mask, dtype=dt) - # Step 3: mask = [0, 0, 1, 2] - mask = math_ops.cumsum(mask, axis=0) - # Step 4: mask = [1, 1, 0, 0] - mask = math_ops.maximum(1. - mask, 0.) - auto_corr *= mask - - # With R[k] := auto_corr[k, ...], - # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]} - # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1) - # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]} - # where M is the filter_beyond_lag truncation point chosen above. - - # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total - # ndims the same as auto_corr - n = _axis_size(states, axis=0) - k = math_ops.range(0., _axis_size(auto_corr, axis=0)) - nk_factor = (n - k) / n - if auto_corr.shape.ndims is not None: - new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1) - else: - new_shape = array_ops.concat( - ([-1], - array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)), - axis=0) - nk_factor = array_ops.reshape(nk_factor, new_shape) - - return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0)) - - -def potential_scale_reduction(chains_states, - independent_chain_ndims=1, - name=None): - """Gelman and Rubin's potential scale reduction factor for chain convergence. - - Given `N > 1` states from each of `C > 1` independent chains, the potential - scale reduction factor, commonly referred to as R-hat, measures convergence of - the chains (to the same target) by testing for equality of means. - Specifically, R-hat measures the degree to which variance (of the means) - between chains exceeds what one would expect if the chains were identically - distributed. See [1], [2]. - - Some guidelines: - - * The initial state of the chains should be drawn from a distribution - overdispersed with respect to the target. - * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1. - Before that, R-hat > 1 (except in pathological cases, e.g. if the chain - paths were identical). - * The above holds for any number of chains `C > 1`. Increasing `C` does - improves effectiveness of the diagnostic. - * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of - course this is problem depedendent. See [2]. - * R-hat only measures non-convergence of the mean. If higher moments, or other - statistics are desired, a different diagnostic should be used. See [2]. - - #### Examples - - Diagnosing convergence by monitoring 10 chains that each attempt to - sample from a 2-variate normal. - - ```python - tfd = tf.contrib.distributions - tfb = tf.contrib.bayesflow - - target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) - - # Get 10 (2x) overdispersed initial states. - initial_state = target.sample(10) * 2. - ==> (10, 2) - - # Get 1000 samples from the 10 independent chains. - chains_states, _ = tfb.hmc.sample_chain( - num_results=1000, - target_log_prob_fn=target.log_prob, - current_state=initial_state, - step_size=0.05, - num_leapfrog_steps=20, - num_burnin_steps=200) - chains_states.shape - ==> (1000, 10, 2) - - rhat = tfb.mcmc_diagnostics.potential_scale_reduction( - chains_states, independent_chain_ndims=1) - - # The second dimension needed a longer burn-in. - rhat.eval() - ==> [1.05, 1.3] - ``` - - To see why R-hat is reasonable, let `X` be a random variable drawn uniformly - from the combined states (combined over all chains). Then, in the limit - `N, C --> infinity`, with `E`, `Var` denoting expectation and variance, - - ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].``` - - Using the law of total variance, the numerator is the variance of the combined - states, and the denominator is the total variance minus the variance of the - the individual chain means. If the chains are all drawing from the same - distribution, they will have the same mean, and thus the ratio should be one. - - [1] "Inference from Iterative Simulation Using Multiple Sequences" - Andrew Gelman and Donald B. Rubin - Statist. Sci. Volume 7, Number 4 (1992), 457-472. - [2] "General Methods for Monitoring Convergence of Iterative Simulations" - Stephen P. Brooks and Andrew Gelman - Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4. - - Args: - chains_states: `Tensor` or Python `list` of `Tensor`s representing the - state(s) of a Markov Chain at each result step. The `ith` state is - assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`. - Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain. - Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent - chains to be tested for convergence to the same target. - The remaining dimensions, `A`, can have any shape (even empty). - independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the - number of giving the number of dimensions, from `dim = 1` to `dim = D`, - holding independent chain results to be tested for convergence. - name: `String` name to prepend to created ops. Default: - `potential_scale_reduction`. - - Returns: - `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for - the state(s). Same `dtype` as `state`, and shape equal to - `state.shape[1 + independent_chain_ndims:]`. - - Raises: - ValueError: If `independent_chain_ndims < 1`. - """ - chains_states_was_list = _is_list_like(chains_states) - if not chains_states_was_list: - chains_states = [chains_states] - - # tensor_util.constant_value returns None iff a constant value (as a numpy - # array) is not efficiently computable. Therefore, we try constant_value then - # check for None. - icn_const_ = tensor_util.constant_value( - ops.convert_to_tensor(independent_chain_ndims)) - if icn_const_ is not None: - independent_chain_ndims = icn_const_ - if icn_const_ < 1: - raise ValueError( - "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format( - independent_chain_ndims)) - - with ops.name_scope(name, "potential_scale_reduction"): - rhat_list = [ - _potential_scale_reduction_single_state(s, independent_chain_ndims) - for s in chains_states - ] - - if chains_states_was_list: - return rhat_list - return rhat_list[0] - - -def _potential_scale_reduction_single_state(state, independent_chain_ndims): - """potential_scale_reduction for one single state `Tensor`.""" - with ops.name_scope( - "potential_scale_reduction_single_state", - values=[state, independent_chain_ndims]): - # We assume exactly one leading dimension indexes e.g. correlated samples - # from each Markov chain. - state = ops.convert_to_tensor(state, name="state") - sample_ndims = 1 - - sample_axis = math_ops.range(0, sample_ndims) - chain_axis = math_ops.range(sample_ndims, - sample_ndims + independent_chain_ndims) - sample_and_chain_axis = math_ops.range( - 0, sample_ndims + independent_chain_ndims) - - n = _axis_size(state, sample_axis) - m = _axis_size(state, chain_axis) - - # In the language of [2], - # B / n is the between chain variance, the variance of the chain means. - # W is the within sequence variance, the mean of the chain variances. - b_div_n = _reduce_variance( - math_ops.reduce_mean(state, sample_axis, keepdims=True), - sample_and_chain_axis, - biased=False) - w = math_ops.reduce_mean( - _reduce_variance(state, sample_axis, keepdims=True, biased=True), - sample_and_chain_axis) - - # sigma^2_+ is an estimate of the true variance, which would be unbiased if - # each chain was drawn from the target. c.f. "law of total variance." - sigma_2_plus = w + b_div_n - - return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n) - - -# TODO(b/72873233) Move some variant of this to sample_stats. -def _reduce_variance(x, axis=None, biased=True, keepdims=False): - with ops.name_scope("reduce_variance"): - x = ops.convert_to_tensor(x, name="x") - mean = math_ops.reduce_mean(x, axis=axis, keepdims=True) - biased_var = math_ops.reduce_mean( - math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims) - if biased: - return biased_var - n = _axis_size(x, axis) - return (n / (n - 1.)) * biased_var - - -def _axis_size(x, axis=None): - """Get number of elements of `x` in `axis`, as type `x.dtype`.""" - if axis is None: - return math_ops.cast(array_ops.size(x), x.dtype) - return math_ops.cast( - math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype) - - -def _is_list_like(x): - """Helper which returns `True` if input is `list`-like.""" - return isinstance(x, (tuple, list)) - - -def _broadcast_maybelist_arg(states, secondary_arg, name): - """Broadcast a listable secondary_arg to that of states.""" - if _is_list_like(secondary_arg): - if len(secondary_arg) != len(states): - raise ValueError("Argument `%s` was a list of different length ({}) than " - "`states` ({})".format(name, len(states))) - else: - secondary_arg = [secondary_arg] * len(states) - - return secondary_arg diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py deleted file mode 100644 index 7bdeaa862d5bb64fa8940df453c7aa2b66023eda..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions to create a Markov Chain Monte Carlo Metropolis step.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.metropolis_hastings_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'evolve', - 'uniform_random_proposal', - 'normal_random_proposal', -] - -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py deleted file mode 100644 index dc1ac68ce009fa46d6c05a3200a29d9fdf245707..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py +++ /dev/null @@ -1,426 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Functions to create a Markov Chain Monte Carlo Metropolis step. - -@@evolve -@@uniform_random_proposal -@@normal_random_proposal -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops - -__all__ = [ - 'evolve', - 'uniform_random_proposal', - 'normal_random_proposal', -] - - -def _single_iteration(current_state, current_log_density, - log_unnormalized_prob_fn, proposal_fn, seed=None, - name='None'): - """Performs a single Metropolis-Hastings step. - - Args: - current_state: Float-like `Tensor` (i.e., `dtype` is either - `tf.float16`, `tf.float32` or `tf.float64`) of any shape that can - be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` - callables. - current_log_density: Float-like `Tensor` with `dtype` and shape equivalent - to `log_unnormalized_prob_fn(current_state)`, i.e., matching the result of - `log_unnormalized_prob_fn` invoked at `current_state`. - log_unnormalized_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `log_unnormalized_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, where `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `log_unnormalized_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair is a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric (e.g., random walk, where the proposal is either - normal or uniform centered at `current_state`), i.e., - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to `None` instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: Python `str` name prefix for ops managed by this function. - - Returns: - next_state: `Tensor` with `dtype` and shape matching `current_state`. - Created by propagating the chain by one step, starting from - `current_state`. - next_log_density: `Tensor` with `dtype` and shape matching - `current_log_density`, which is equal to the value of the unnormalized - `log_unnormalized_prob_fn` computed at `next_state`. - log_accept_ratio: `Tensor` with `dtype` and shape matching - `current_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio used in generating the `next_state`. - """ - - with ops.name_scope(name, 'single_iteration', [current_state]): - # The proposed state and the log of the corresponding Hastings ratio. - proposal_state, log_transit_ratio = proposal_fn(current_state) - - # If the log ratio is None, assume that the transitions are symmetric, - # i.e., Prob(Current -> Proposed) = Prob(Proposed -> Current). - if log_transit_ratio is None: - log_transit_ratio = 0. - - # Log-density of the proposal state. - proposal_log_density = log_unnormalized_prob_fn(proposal_state) - - # Ops to compute the log of the acceptance ratio. Recall that the - # acceptance ratio is: [Prob(Proposed) / Prob(Current)] * - # [Prob(Proposed -> Current) / Prob(Current -> Proposed)]. The log of the - # second term is the log_transit_ratio. - with ops.name_scope('accept_reject'): - # The log of the acceptance ratio. - log_accept_ratio = (proposal_log_density - current_log_density - + log_transit_ratio) - - # A proposal is accepted or rejected depending on the acceptance ratio. - # If the acceptance ratio is greater than 1 then it is always accepted. - # If the acceptance ratio is less than 1 then the proposal is accepted - # with probability = acceptance ratio. As we are working in log space to - # prevent over/underflows, this logic is expressed in log terms below. - # If a proposal is accepted we place a True in the acceptance state - # tensor and if it is to be rejected we place a False. - # The log_draws below have to be compared to the log_accept_ratio so we - # make sure that they have the same data type. - log_draws = math_ops.log(random_ops.random_uniform( - array_ops.shape(current_log_density), seed=seed, - dtype=log_accept_ratio.dtype)) - is_proposal_accepted = log_draws < log_accept_ratio - - # The acceptance state decides which elements of the current state are to - # be replaced with the corresponding elements in the proposal state. - with ops.name_scope(name, 'metropolis_single_step', - [current_state, current_log_density]): - next_log_density = array_ops.where(is_proposal_accepted, - proposal_log_density, - current_log_density) - next_state = array_ops.where(is_proposal_accepted, proposal_state, - current_state) - - return next_state, next_log_density, log_accept_ratio - - -def evolve(initial_sample, - initial_log_density, - initial_log_accept_ratio, - log_unnormalized_prob_fn, - proposal_fn, - n_steps=1, - seed=None, - name=None): - """Performs `n_steps` of the Metropolis-Hastings update. - - Given a probability density function, `f(x)` and a proposal scheme which - generates new points from old, this `Op` returns a tensor - which may be used to generate approximate samples from the target distribution - using the Metropolis-Hastings algorithm. These samples are from a Markov chain - whose equilibrium distribution matches the target distribution. - - The probability distribution may have an unknown normalization constan. - We parameterize the probability density as follows: - ``` - f(x) = exp(L(x) + constant) - ``` - Here `L(x)` is any continuous function with an (possibly unknown but finite) - upper bound, i.e. there exists a number beta such that - `L(x)< beta < infinity` for all x. The constant is the normalization needed - to make `f(x)` a probability density (as opposed to just a finite measure). - - Although `initial_sample` can be arbitrary, a poor choice may result in a - slow-to-mix chain. In many cases the best choice is the one that maximizes - the target density, i.e., choose `initial_sample` such that - `f(initial_sample) >= f(x)` for all `x`. - - - If the support of the distribution is a strict subset of R^n (but of non zero - measure), then the unnormalized log-density `L(x)` should return `-infinity` - outside the support domain. This effectively forces the sampler to only - explore points in the regions of finite support. - - Usage: - This function is meant to be wrapped up with some of the common proposal - schemes (e.g. random walk, Langevin diffusion etc) to produce a more user - friendly interface. However, it may also be used to create bespoke samplers. - - The following example, demonstrates the use to generate a 1000 uniform random - walk Metropolis samplers run in parallel for the normal target distribution. - ```python - n = 3 # dimension of the problem - - # Generate 1000 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = tf.get_variable( - 'state',initializer=tf.random_normal([1000, n], mean=3.0, - dtype=tf.float64, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = tf.get_variable( - 'state_log_density', initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = tf.get_variable( - 'log_acceptance_ratio', initializer=tf.zeros([1000], dtype=tf.float64)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = tf.initialize_all_variables() - with tf.Session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps and print out the mean across - # the chains every 100 iterations. - for n_iter in range(10): - # Executing the stepper advances the chain to the next state. - sess.run(stepper) - # Print out the current value of the mean(sample) for every dimension. - print(np.mean(sess.run(state), 0)) - # Estimated covariance matrix - samples = sess.run(state) - print('') - print(np.cov(samples, rowvar=False)) - ``` - - Args: - initial_sample: A float-like `tf.Variable` of any shape that can - be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` - callables. - initial_log_density: Float-like `tf.Variable` with `dtype` and shape - equivalent to `log_unnormalized_prob_fn(initial_sample)`, i.e., matching - the result of `log_unnormalized_prob_fn` invoked at `current_state`. - initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching - `initial_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio after propagating the chain for `n_steps`. - log_unnormalized_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `log_unnormalized_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, here `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `log_unnormalized_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair should be a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric, i.e. - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to None instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - n_steps: A positive `int` or a scalar `int32` tensor. Sets the number of - iterations of the chain. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - forward_step: an `Op` to step the Markov chain forward for `n_steps`. - """ - - with ops.name_scope(name, 'metropolis_hastings', [initial_sample]): - current_state = initial_sample - current_log_density = initial_log_density - log_accept_ratio = initial_log_accept_ratio - - # Stop condition for the while_loop - def stop_condition(i, _): - return i < n_steps - - def step(i, loop_vars): - """Wrap `_single_iteration` for `while_loop`.""" - state = loop_vars[0] - state_log_density = loop_vars[1] - return i + 1, list(_single_iteration(state, state_log_density, - log_unnormalized_prob_fn, - proposal_fn, seed=seed)) - - loop_vars = [current_state, current_log_density, log_accept_ratio] - # Build an `Op` to evolve the Markov chain for `n_steps` - (_, [end_state, end_log_density, end_log_acceptance]) = ( - control_flow_ops.while_loop( - stop_condition, step, - (0, loop_vars), - parallel_iterations=1, swap_memory=1)) - - forward_step = control_flow_ops.group( - state_ops.assign(current_log_density, end_log_density), - state_ops.assign(current_state, end_state), - state_ops.assign(log_accept_ratio, end_log_acceptance)) - - return forward_step - - -def uniform_random_proposal(step_size=1., - seed=None, - name=None): - """Returns a callable that adds a random uniform tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). It adds a - sample from a random uniform distribution drawn from [-stepsize, stepsize] - to its input. It also returns the log of the ratio of the probability of - moving from the input point to the proposed point, but since this log ratio is - identically equal to 0 (because the probability of drawing a value `x` from - the symmetric uniform distribution is the same as the probability of drawing - `-x`), it simply returns None for the second element of the returned tuple. - - Args: - step_size: A positive `float` or a scalar tensor of real dtype - controlling the scale of the uniform distribution. - If step_size = a, then draws are made uniformly from [-a, a]. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, 'uniform_random_proposal', [step_size]): - def proposal_fn(input_state, name=None): - """Adds a uniform perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensot` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - with ops.name_scope(name, 'proposer', [input_state]): - input_state = ops.convert_to_tensor(input_state, name='input_state') - return input_state + random_ops.random_uniform( - array_ops.shape(input_state), - minval=-step_size, - maxval=step_size, - seed=seed), None - return proposal_fn - - -def normal_random_proposal(scale=1., - seed=None, - name=None): - """Returns a callable that adds a random normal tensor to the input. - - This function returns a callable that accepts one `Tensor` argument of any - shape and a real data type (i.e. `tf.float32` or `tf.float64`). The callable - adds a sample from a normal distribution with the supplied standard deviation - and zero mean to its input argument (called the proposal point). - The callable returns a tuple with the proposal point as the first element. - The second element is identically `None`. It is included so the callable is - compatible with the expected signature of the proposal scheme argument in the - `metropolis_hastings` function. A value of `None` indicates that the - probability of going from the input point to the proposal point is equal to - the probability of going from the proposal point to the input point. - - Args: - scale: A positive `float` or a scalar tensor of any real dtype controlling - the scale of the normal distribution. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: A string that sets the name for this `Op`. - - Returns: - proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. - """ - - with ops.name_scope(name, 'normal_random_proposal', [scale]): - def proposal_fn(input_state, name=None): - """Adds a normal perturbation to the input state. - - Args: - input_state: A `Tensor` of any shape and real dtype. - name: A string that sets the name for this `Op`. - - Returns: - proposal_state: A float-like `Tensot` with `dtype` and shape matching - `input_state`. - log_transit_ratio: `None`. Proposal is symmetric. - """ - - with ops.name_scope(name, 'proposer', [input_state]): - input_state = ops.convert_to_tensor(input_state, name='input_state') - return input_state + random_ops.random_normal( - array_ops.shape(input_state), - mean=0., - stddev=scale, - seed=seed), None - return proposal_fn diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index 985177e897f443989e466d1a498c461a30aeb5cb..032b859d469ee5039e08e4af4c2f4ebf35c2ff19 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -44,15 +44,13 @@ def expectation_importance_sampler(f, n=None, seed=None, name='expectation_importance_sampler'): - r"""Monte Carlo estimate of `E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]`. + r"""Monte Carlo estimate of \\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\). - With `p(z) := exp{log_p(z)}`, this `Op` returns + With \\(p(z) := exp^{log_p(z)}\\), this `Op` returns - ``` - n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q, - \approx E_q[ f(Z) p(Z) / q(Z) ] - = E_p[f(Z)] - ``` + \\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,\\) + \\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\) + \\(= E_p[f(Z)]\\) This integral is done in log-space with max-subtraction to better handle the often extreme values that `f(z) p(z) / q(z)` can take on. @@ -95,9 +93,9 @@ def expectation_importance_sampler(f, log_values = log_f_z + log_p_z - q_log_prob_z return _logspace_mean(log_values) - # With f_plus(z) = max(0, f(z)), f_minus(z) = max(0, -f(z)), - # E_p[f(Z)] = E_p[f_plus(Z)] - E_p[f_minus(Z)] - # = E_p[f_plus(Z) + 1] - E_p[f_minus(Z) + 1] + # With \\(f_{plus}(z) = max(0, f(z)), f_{minus}(z) = max(0, -f(z))\\), + # \\(E_p[f(Z)] = E_p[f_{plus}(Z)] - E_p[f_{minus}(Z)]\\) + # \\( = E_p[f_{plus}(Z) + 1] - E_p[f_{minus}(Z) + 1]\\) # Without incurring bias, 1 is added to each to prevent zeros in logspace. # The logarithm is approximately linear around 1 + epsilon, so this is good # for small values of 'z' as well. @@ -121,14 +119,12 @@ def expectation_importance_sampler_logspace( name='expectation_importance_sampler_logspace'): r"""Importance sampling with a positive function, in log-space. - With `p(z) := exp{log_p(z)}`, and `f(z) = exp{log_f(z)}`, this `Op` - returns + With \\(p(z) := exp^{log_p(z)}\\), and \\(f(z) = exp{log_f(z)}\\), + this `Op` returns - ``` - Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q, - \approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ] - = Log[E_p[f(Z)]] - ``` + \\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,\\) + \\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\) + \\(= Log[E_p[f(Z)]]\\) This integral is done in log-space with max-subtraction to better handle the often extreme values that `f(z) p(z) / q(z)` can take on. @@ -196,13 +192,11 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of `E_p[f(X)]`. + """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). This function computes the Monte-Carlo approximation of an expectation, i.e., - ```none - E_p[f(X)] approx= m**-1 sum_i^m f(x_j), x_j ~iid p(X) - ``` + \\(E_p[f(X)] \approx= m^{-1} sum_i^m f(x_j), x_j\ ~iid\ p(X)\\) where: @@ -216,8 +210,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, parameterless distribution (e.g., `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and expectation, i.e., - `grad[ Avg{ s_i : i=1...n } ] = Avg{ grad[s_i] : i=1...n }` where - `S_n = Avg{s_i}` and `s_i = f(x_i), x_i ~ p`. + grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n } where + S_n = Avg{\\(s_i\\)}` and `\\(s_i = f(x_i), x_i ~ p\\). However, if p is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of non-reparameterized distributions. @@ -296,7 +290,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Args: f: Python callable which can return `f(samples)`. samples: `Tensor` of samples used to form the Monte-Carlo approximation of - `E_p[f(X)]`. A batch of samples should be indexed by `axis` dimensions. + \\(E_p[f(X)]\\). A batch of samples should be indexed by `axis` + dimensions. log_prob: Python callable which can return `log_prob(samples)`. Must correspond to the natural-logarithm of the pdf/pmf of each sample. Only required/used if `use_reparametrization=False`. @@ -316,7 +311,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Returns: approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation - of `E_p[f(X)]`. + of \\(E_p[f(X)]\\). Raises: ValueError: if `f` is not a Python `callable`. @@ -328,7 +323,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, if not callable(f): raise ValueError('`f` must be a callable function.') if use_reparametrization: - return math_ops.reduce_mean(f(samples), axis=axis, keep_dims=keep_dims) + return math_ops.reduce_mean(f(samples), axis=axis, keepdims=keep_dims) else: if not callable(log_prob): raise ValueError('`log_prob` must be a callable function.') @@ -348,7 +343,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, # "Is there a floating point value of x, for which x-x == 0 is false?" # http://stackoverflow.com/q/2686644 fx += stop(fx) * (logpx - stop(logpx)) # Add zeros_like(logpx). - return math_ops.reduce_mean(fx, axis=axis, keep_dims=keep_dims) + return math_ops.reduce_mean(fx, axis=axis, keepdims=keep_dims) def _sample_mean(values): diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py deleted file mode 100644 index 7786656398e3c87704227be95b3cd23a38785249..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An optimizer module for stochastic gradient Langevin dynamics.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import training_ops - - -class SGLDOptimizer(optimizer.Optimizer): - """An optimizer module for stochastic gradient Langevin dynamics. - - This implements the preconditioned Stochastic Gradient Langevin Dynamics - optimizer [1]. The optimization variable is regarded as a sample from the - posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in - each dimension according to RMSProp [2]. - - Note: If a prior is included in the loss, it should be scaled by - `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches - in the data. I.e., it should be divided by the `num_pseudo_batches` term - described below. - - [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural - Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. - ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 - [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf - - Args: - learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the - optimizer. Must be tuned to the specific function being minimized. - preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential - decay rate of the rescaling of the preconditioner (RMSprop). (This is - "alpha" in [1]). Should be smaller than but nearly `1` to approximate - sampling from the posterior. (Default: `0.95`) - num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of - minibatches in the data set. Trades off noise and prior with the SGD - likelihood term. Note: Assumes the loss is taken as the mean over a - minibatch. Otherwise if the sum was taken, divide this number by the - batch size. (Default: `1`) - burnin: Scalar `int`-like `Tensor`. The number of iterations to collect - gradient statistics to update the preconditioner before starting to draw - noisy samples. (Default: `25`) - diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of - the preconditioner to prevent the preconditioner from degenerating. - (Default: `1e-8`) - name: Python `str` describing ops managed by this function. - (Default: `"SGLDOptimizer"`) - variable_scope: Variable scope used for calls to `tf.get_variable`. - If `None`, a new variable scope is created using name - `ops.get_default_graph().unique_name(name or default_name)`. - - Raises: - InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in - `(0,1]`. - """ - - def __init__(self, - learning_rate, - preconditioner_decay_rate=0.95, - num_pseudo_batches=1, - burnin=25, - diagonal_bias=1e-8, - name=None, - variable_scope=None): - default_name = 'SGLDOptimizer' - with ops.name_scope(name, default_name, [ - learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, - diagonal_bias - ]): - if variable_scope is None: - var_scope_name = ops.get_default_graph().unique_name( - name or default_name) - with varscope_ops.variable_scope(var_scope_name) as scope: - self._variable_scope = scope - else: - self._variable_scope = variable_scope - - self._preconditioner_decay_rate = ops.convert_to_tensor( - preconditioner_decay_rate, name='preconditioner_decay_rate') - self._num_pseudo_batches = ops.convert_to_tensor( - num_pseudo_batches, name='num_pseudo_batches') - self._burnin = ops.convert_to_tensor(burnin, name='burnin') - self._diagonal_bias = ops.convert_to_tensor( - diagonal_bias, name='diagonal_bias') - self._learning_rate = ops.convert_to_tensor( - learning_rate, name='learning_rate') - - with varscope_ops.variable_scope(self._variable_scope): - self._counter = varscope_ops.get_variable( - 'counter', initializer=0, trainable=False) - - self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._preconditioner_decay_rate, - message='`preconditioner_decay_rate` must be non-negative'), - check_ops.assert_less_equal( - self._preconditioner_decay_rate, - 1., - message='`preconditioner_decay_rate` must be at most 1.'), - ], self._preconditioner_decay_rate) - - self._num_pseudo_batches = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._num_pseudo_batches, - 0, - message='`num_pseudo_batches` must be greater than zero') - ], self._num_pseudo_batches) - - self._burnin = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin, message='`burnin` must be non-negative'), - check_ops.assert_integer( - self._burnin, message='`burnin` must be an integer') - ], self._burnin) - - self._diagonal_bias = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._diagonal_bias, - message='`diagonal_bias` must be non-negative') - ], self._diagonal_bias) - - super(SGLDOptimizer, self).__init__(use_locking=False, - name=name or default_name) - - def _create_slots(self, var_list): - for v in var_list: - init_rms = init_ops.ones_initializer(dtype=v.dtype) - self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), - v.dtype, 'rms', self._name) - - def _prepare(self): - # We need to put the conversion and check here because a user will likely - # want to decay the learning rate dynamically. - self._learning_rate_tensor = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._learning_rate, message='`learning_rate` must be non-negative') - ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) - self._decay_tensor = ops.convert_to_tensor( - self._preconditioner_decay_rate, name='preconditioner_decay_rate') - - super(SGLDOptimizer, self)._prepare() - - def _apply_dense(self, grad, var): - rms = self.get_slot(var, 'rms') - - with ops.control_dependencies([ - self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, - var.dtype.base_dtype))]): - new_grad = self._apply_noisy_update(rms, grad) - - return training_ops.apply_gradient_descent( - var, - math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, - use_locking=self._use_locking).op - - def _apply_sparse(self, grad, var): - rms = self.get_slot(var, 'rms') - - with ops.control_dependencies([ - self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, - var.dtype.base_dtype))]): - new_grad = self._apply_noisy_update(rms, grad) - - return training_ops.apply_gradient_descent( - var, - math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), - new_grad, - use_locking=self._use_locking).op - - def _finish(self, update_ops, name_scope): - update_ops.append([self._counter.assign_add(1)]) - return control_flow_ops.group(*update_ops, name=name_scope) - - @property - def variable_scope(self): - """Variable scope of all calls to `tf.get_variable`.""" - return self._variable_scope - - def _apply_noisy_update(self, mom, grad): - # Compute and apply the gradient update following - # preconditioned Langevin dynamics - stddev = array_ops.where( - array_ops.squeeze(self._counter > self._burnin), - math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), - array_ops.zeros([], grad.dtype)) - - preconditioner = math_ops.rsqrt( - mom + math_ops.cast(self._diagonal_bias, grad.dtype)) - return ( - 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, - grad.dtype) + - random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * - stddev * math_ops.sqrt(preconditioner)) - - def _update_momentum(self, mom, grad, decay): - # Keep an exponentially weighted moving average of squared gradients. - # Not thread safe - return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py deleted file mode 100644 index ca3d75b5bfee093449026c7d1d62e3bdeff6b096..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions related to managing `tf.Variable`s. - -@@externalize_variables_as_args -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -from tensorflow.python.framework import ops -from tensorflow.python.ops import gradients_impl as gradients_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.ops import variables as variables_ops - -__all__ = [ - "externalize_variables_as_args", -] - - -# Cause all warnings to always be triggered. -# Not having this means subsequent calls wont trigger the warning. -warnings.simplefilter("always") - - -def externalize_variables_as_args(fn, - fn_args=(), - ancestor_variables=None, - possible_ancestor_vars=None, - assert_variable_override=False, - name=None): - """"Converts variables within a callable into explicit args. - - Makes a new callable from `fn` which has arguments `list(fn_args) + - list(ancestor_variables)`. If `ancestor_variables` is not specified, it is - inferred by checking which of `possible_ancestor_vars` actually influences the - return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`). - By default `possible_ancestor_vars` is `tf.trainable_variables() + - tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`. - - #### Examples: - - ```python - num_samples = 2 - num_dims = 1 - dtype = np.float32 - - def foo(x): - x = tf.convert_to_tensor(x, dtype=dtype, name="x") - s = x.shape.as_list() - y = tf.get_variable( - name="y", - dtype=dtype, - initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) - return x + y - - x = tf.constant(dtype([0.1, 0.2])) - - wrapped_foo, discovered_ancestor_variables = ( - externalize_variables_as_args(foo, [x])) - - new_x = dtype([[1.], [2.]]) - new_y = dtype([[3.], [4.]]) - new_result = wrapped_foo(new_x, new_y) - # ==> [[4.], [6.]] - - discovered_ancestor_variables == [tf.get_variable("y", dtype)] - # ==> [True] - ``` - - Args: - fn: Python callable which returns a `Tensor` and accepts `*fn_args`. - fn_args: Python list of args to `fn`. Represents dummy arguments passed to - `fn` to trace its execution; actual values are unimportant. These args are - only used to construct the output of `fn` and to resolve the ancestor - `tf.Variable`s. - Default value: `()` (i.e., `fn` takes no args). - ancestor_variables: Python list of `tf.Variable`s. When `None` the list is - expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing - the `ancestor_variables` the internal call to `fn` is avoided. - Default value: `None` (i.e., `tf.Variable` dependencies are discovered). - possible_ancestor_vars: Python list of possible `tf.Variable`s which might - be a dependency of computing `fn(*fn_args)`. - Default value: `None` (i.e., expanded as described above). - assert_variable_override: Python `bool` indicating that not finding a - `tf.Variable` in the override list is an exception. - Default value: `False` (i.e., missing a `Variable` triggers a `warning`). - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "externalize_variables_as_args"). - - Returns: - wrapped_fn: Python callable taking arguments like - `*(list(fn_args) + discovered_ancestor_variables)`. - discovered_ancestor_variables: Python list of `tf.Variable`s known to be a - dependency of `fn(*fn_args)`. - - Raises: - ValueError: if `assert_variable_override` is `True` and `Variable` is - requested but not overridden. - """ - def _make_bypassing_custom_getter_fn(new_var_dict): - """Return dict value rather than what would otherwise be dict key.""" - def _custom_getter(getter, *args, **kwargs): - v = getter(*args, **kwargs) - new_v = new_var_dict.get(v, None) - if new_v is None: - msg = "Variable \"{}\" not found in bypass dict.".format(v) - if assert_variable_override: - raise ValueError(msg) - warnings.warn(msg) - return v - return new_v - return _custom_getter - - with ops.name_scope(name, "externalize_variables_as_args"): - if ancestor_variables is not None and not ancestor_variables: - return fn, () - if ancestor_variables is None: - y = fn(*fn_args) # Side-effect: adds trainable vars. - if possible_ancestor_vars is None: - possible_ancestor_vars = ( - variables_ops.trainable_variables() + - ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) - # TODO(b/72873296): Add a dedicated op for identifying ancestors. - ancestors = [v for g, v - in zip(gradients_ops.gradients(y, possible_ancestor_vars), - possible_ancestor_vars) - if g is not None] - ancestor_variables = sorted(ancestors, key=lambda v: v.name) - n = len(fn_args) - def _fn(*args): - with ops.name_scope("wrapped_fn"): - vars_dict = dict( - (k, ops.convert_to_tensor( - v, dtype=k.dtype.base_dtype, name=k.op.name)) - for k, v in zip(ancestor_variables, args[n:])) - with varscope_ops.variable_scope( - varscope_ops.get_variable_scope(), - reuse=True, - custom_getter=_make_bypassing_custom_getter_fn(vars_dict)): - return fn(*args[:n]) - return _fn, ancestor_variables diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py deleted file mode 100644 index 4d5f0cfe9713a011b32c5aba8d429847d81f33e2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""An optimizer module for constant stochastic gradient descent.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope as varscope_ops -from tensorflow.python.training import optimizer -from tensorflow.python.training import training_ops - - -class VariationalSGDOptimizer(optimizer.Optimizer): - """An optimizer module for constant stochastic gradient descent. - - This implements an optimizer module for the constant stochastic gradient - descent algorithm [1]. The optimization variable is regarded as an - approximate sample from the posterior . - - Note: If a prior is included in the loss, it should be scaled by - `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches - in the data. I.e., it should be divided by the `num_pseudo_batches` term - described below. - - [1]: "Stochastic Gradient Descent as Approximate Bayesian Inference - Stephan Mandt, Matthew D. Hoffman, David M. Blei. - ArXiv:1704.04289, 2017. https://arxiv.org/abs/1704.04289 - - Args: - batch_size: Scalar `int`-like `Tensor`. The number of examples in a - minibatch in the data set. Note: Assumes the loss is taken as the mean - over a minibatch. Otherwise if the sum was taken set this to 1. - total_num_examples: Scalar `int`-like `Tensor`. The total number of examples - in the data set. - max_learning_rate: Scalar `float`-like `Tensor`. A maximum allowable - effective coordinate-wise learning rate. The algorithm scales down any - effective learning rate (i.e. after preconditioning) that is larger than - this. (Default: `1`) - preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential - decay rate of the rescaling of the preconditioner (RMSprop). (This is - "alpha" in [1]). Should be smaller than but nearly `1` to approximate - sampling from the posterior. (Default: `0.95`) - burnin: Scalar `int`-like `Tensor`. The number of iterations to collect - gradient statistics to update the preconditioner before starting to draw - noisy samples. (Default: `25`) - burnin_max_learning_rate: Scalar `float`-like `Tensor`. Maximum learning - rate to use during the burnin period. - (Default: `1e-8`) - use_single_learning_rate: Boolean Indicates whether one single learning - rate is used or coordinate_wise learning rates are used. - (Default: `False`) - name: Python `str` describing ops managed by this function. - (Default: `"VariationalSGDOptimizer"`) - variable_scope: Variable scope used for calls to `tf.get_variable`. - If `None`, a new variable scope is created using name - `ops.get_default_graph().unique_name(name or default_name)`. - - Raises: - InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in - `(0,1]`. - """ - - def __init__(self, - batch_size, - total_num_examples, - max_learning_rate=1.0, - preconditioner_decay_rate=0.95, - burnin=25, - burnin_max_learning_rate=1e-6, - use_single_learning_rate=False, - name=None, - variable_scope=None): - default_name = 'VariationalSGDOptimizer' - with ops.name_scope(name, default_name, [ - max_learning_rate, preconditioner_decay_rate, batch_size, burnin, - burnin_max_learning_rate - ]): - if variable_scope is None: - var_scope_name = ops.get_default_graph().unique_name( - name or default_name) - with varscope_ops.variable_scope(var_scope_name) as scope: - self._variable_scope = scope - else: - self._variable_scope = variable_scope - - self._preconditioner_decay_rate = ops.convert_to_tensor( - preconditioner_decay_rate, name='preconditioner_decay_rate') - self._batch_size = ops.convert_to_tensor(batch_size, name='batch_size') - self._total_num_examples = ops.convert_to_tensor( - total_num_examples, name='total_num_examples') - self._burnin = ops.convert_to_tensor(burnin, name='burnin') - self._burnin_max_learning_rate = ops.convert_to_tensor( - burnin_max_learning_rate, name='burnin_max_learning_rate') - self._max_learning_rate = ops.convert_to_tensor( - max_learning_rate, name='max_learning_rate') - self._use_single_learning_rate = use_single_learning_rate - - with varscope_ops.variable_scope(self._variable_scope): - self._counter = varscope_ops.get_variable( - 'counter', initializer=0, trainable=False) - - self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._preconditioner_decay_rate, - message='`preconditioner_decay_rate` must be non-negative'), - check_ops.assert_less_equal( - self._preconditioner_decay_rate, - 1., - message='`preconditioner_decay_rate` must be at most 1.'), - ], self._preconditioner_decay_rate) - - self._batch_size = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._batch_size, - 0, - message='`batch_size` must be greater than zero') - ], self._batch_size) - - self._total_num_examples = control_flow_ops.with_dependencies([ - check_ops.assert_greater( - self._total_num_examples, - 0, - message='`total_num_examples` must be greater than zero') - ], self._total_num_examples) - - self._burnin = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin, message='`burnin` must be non-negative'), - check_ops.assert_integer( - self._burnin, message='`burnin` must be an integer') - ], self._burnin) - - self._burnin_max_learning_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._burnin_max_learning_rate, - message='`burnin_max_learning_rate` must be non-negative') - ], self._burnin_max_learning_rate) - - self._max_learning_rate = control_flow_ops.with_dependencies([ - check_ops.assert_non_negative( - self._max_learning_rate, - message='`max_learning_rate` must be non-negative') - ], self._max_learning_rate) - - super(VariationalSGDOptimizer, self).__init__( - use_locking=False, name=name or default_name) - - def _create_slots(self, var_list): - for v in var_list: - init_moment = init_ops.zeros_initializer(dtype=v.dtype) - self._get_or_make_slot_with_initializer( - v, init_moment, v.get_shape(), v.dtype, 'first_moment', self._name) - self._get_or_make_slot_with_initializer( - v, init_moment, v.get_shape(), v.dtype, 'second_moment', self._name) - - def _prepare(self): - self._decay_tensor = ops.convert_to_tensor( - self._preconditioner_decay_rate, name='preconditioner_decay_rate') - self._batch_size_tensor = ops.convert_to_tensor( - self._batch_size, name='batch_size_tensor') - - super(VariationalSGDOptimizer, self)._prepare() - - def _get_coordinatewise_learning_rate(self, grad, var): - # Compute the learning rate using a moving average for the diagonal of BB^T - avg_first = self.get_slot(var, 'first_moment') - avg_second = self.get_slot(var, 'second_moment') - decay_tensor = math_ops.cast(self._decay_tensor, var.dtype) - batch_size = math_ops.cast(self._batch_size_tensor, var.dtype) - - # Create an estimator for the moving average of gradient mean and variance - # via Welford's algorithm - if isinstance(grad, ops.Tensor): - delta = grad - avg_first - first_moment_update = avg_first.assign_add( - array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype), - 1. - decay_tensor) * delta) - - with ops.control_dependencies([first_moment_update]): - second_moment_update = avg_second.assign_add( - math_ops.cast(self._counter < 1, var.dtype) * - -(1. - decay_tensor) * ( - avg_second - decay_tensor * math_ops.square(delta))) - diag_preconditioner = control_flow_ops.with_dependencies( - [second_moment_update], - clip_ops.clip_by_value(avg_second, 1e-12, 1e12)) - elif isinstance(grad, ops.IndexedSlices): - delta = grad.values - array_ops.gather_nd(avg_first, grad.indices) - first_moment_update = state_ops.scatter_add( - avg_first, - grad.indices, - array_ops.where(self._counter < 1, - math_ops.cast(1., var.dtype), - 1. - decay_tensor) * delta) - - with ops.control_dependencies([first_moment_update]): - avg_second = state_ops.scatter_add( - avg_second, - grad.indices, - math_ops.cast(self._counter < 1, var.dtype) * - -(1. - decay_tensor) * ( - array_ops.gather_nd(avg_second, grad.indices) - decay_tensor * - math_ops.square(delta))) - avg_second = array_ops.gather_nd(avg_second, grad.indices) - # TODO(b/70783772) - diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12) - else: - raise errors.InvalidArgumentError( - None, None, 'grad must of type Tensor or IndexedSlice') - - diag_preconditioner *= batch_size - - if self._use_single_learning_rate: - diag_preconditioner = math_ops.reduce_mean(diag_preconditioner) - - # From Theorem 2 Corollary 1 of Mandt et al. 2017 - return 2. * batch_size / ( - math_ops.cast(self._total_num_examples, var.dtype.base_dtype) * - diag_preconditioner) - - def _apply_dense(self, grad, var): - - max_learning_rate = array_ops.where(self._counter < self._burnin, - self._burnin_max_learning_rate, - self._max_learning_rate) - - learn_rates = clip_ops.clip_by_value( - self._get_coordinatewise_learning_rate(grad, var), 0.0, - math_ops.cast(max_learning_rate, var.dtype.base_dtype)) - - newgrad = grad * learn_rates - return training_ops.apply_gradient_descent( - var, - math_ops.cast(1.0, var.dtype), - newgrad, - use_locking=self._use_locking).op - - def _apply_sparse(self, grad, var): - - max_learning_rate = array_ops.where(self._counter < self._burnin, - self._burnin_max_learning_rate, - self._max_learning_rate) - - learn_rate = clip_ops.clip_by_value( - self._get_coordinatewise_learning_rate(grad, var), 0.0, - math_ops.cast(max_learning_rate, var.dtype)) - delta = grad.values * learn_rate - - return state_ops.scatter_sub(var, grad.indices, delta, - use_locking=self._use_locking) - - def _finish(self, update_ops, name_scope): - update_ops.append([self._counter.assign_add(1)]) - return control_flow_ops.group(*update_ops, name=name_scope) - - @property - def variable_scope(self): - """Variable scope of all calls to `tf.get_variable`.""" - return self._variable_scope diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 6fdcd0f996ee011842a5add79f06264a28a2145c..8eac1243ef63dd09c5c5dad4bcd9bd7a15f58900 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -14,15 +14,6 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = ["**/OWNERS"], - ), - visibility = ["//tensorflow:__subpackages__"], -) - package_group(name = "friends") cc_library( @@ -128,7 +119,7 @@ py_library( py_test( name = "gbdt_batch_test", - size = "small", + size = "medium", srcs = ["python/training/functions/gbdt_batch_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 289f5bb3140974d8c37f4938ceef27275b099f9a..8cff1a3bb1d11aff6a264636291a7149b40de516 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -10,23 +10,17 @@ package( load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", + deps = [ + "custom_export_strategy", + ":custom_loss_head", + ":estimator", + ":model", + ":trainer_hooks", + ], ) py_library( @@ -34,12 +28,13 @@ py_library( srcs = ["model.py"], srcs_version = "PY2AND3", deps = [ + ":estimator_utils", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", "//tensorflow/python:framework_ops", "//tensorflow/python:state_ops", - "//tensorflow/python:training", + "//tensorflow/python:training_util", ], ) @@ -57,6 +52,18 @@ py_library( ], ) +py_library( + name = "estimator_utils", + srcs = ["estimator_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + ], +) + py_test( name = "trainer_hooks_test", size = "small", @@ -124,6 +131,7 @@ py_library( srcs = ["estimator.py"], srcs_version = "PY2AND3", deps = [ + ":estimator_utils", ":model", "//tensorflow/contrib/boosted_trees:losses", "//tensorflow/contrib/learn", @@ -136,6 +144,7 @@ py_library( srcs = ["dnn_tree_combined_estimator.py"], srcs_version = "PY2AND3", deps = [ + ":estimator_utils", ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees:model_ops_py", @@ -149,7 +158,7 @@ py_library( py_test( name = "dnn_tree_combined_estimator_test", - size = "small", + size = "medium", srcs = ["dnn_tree_combined_estimator_test.py"], srcs_version = "PY2AND3", tags = [ @@ -165,3 +174,22 @@ py_test( "//tensorflow/python:framework_for_generated_wrappers", ], ) + +py_test( + name = "estimator_test", + size = "medium", + srcs = ["estimator_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", + "notsan", + ], + deps = [ + ":estimator", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 31f5c444817b9b82723c86bea3504d4934e57eb8..62f1f4122b05b56a708823df4246d618bd3fa5d4 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -39,7 +39,8 @@ _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" def make_custom_export_strategy(name, convert_fn, feature_columns, - export_input_fn): + export_input_fn, + use_core_columns=False): """Makes custom exporter of GTFlow tree format. Args: @@ -54,11 +55,11 @@ def make_custom_export_strategy(name, An `ExportStrategy`. """ base_strategy = saved_model_export_utils.make_export_strategy( - serving_input_fn=export_input_fn) + serving_input_fn=export_input_fn, strip_default_attrs=True) input_fn = export_input_fn() (sorted_feature_names, dense_floats, sparse_float_indices, _, _, sparse_int_indices, _, _) = gbdt_batch.extract_features( - input_fn.features, feature_columns) + input_fn.features, feature_columns, use_core_columns) def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None): """A wrapper to export to SavedModel, and convert it to other formats.""" @@ -93,7 +94,9 @@ def make_custom_export_strategy(name, "w") as f: f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) return result_dir - return export_strategy.ExportStrategy(name, export_fn) + + return export_strategy.ExportStrategy( + name, export_fn, strip_default_attrs=True) def convert_to_universal_format(dtec, sorted_feature_names, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index cec3892b57655dc967b4e7926f7f5a6a30084487..9994c84ebdb930eea0818188225488eb5eca84eb 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -19,14 +19,13 @@ logits of the DNN. The input layer of the DNN (including the embeddings learned over sparse features) can optionally be provided to the boosted trees as an additional input feature. """ - from __future__ import absolute_import from __future__ import division from __future__ import print_function import six - from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch @@ -34,6 +33,7 @@ from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import nn @@ -43,10 +43,8 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary from tensorflow.python.training import training_util - _DNN_LEARNING_RATE = 0.001 - def _get_optimizer(optimizer): if callable(optimizer): return optimizer() @@ -59,16 +57,25 @@ def _add_hidden_layer_summary(value, tag): summary.histogram("%s_activation" % tag, value) -def _dnn_tree_combined_model_fn( - features, labels, mode, head, dnn_hidden_units, - dnn_feature_columns, tree_learner_config, num_trees, - tree_examples_per_layer, - config=None, dnn_optimizer="Adagrad", - dnn_activation_fn=nn.relu, dnn_dropout=None, - dnn_input_layer_partitioner=None, - dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, - tree_feature_columns=None, - tree_center_bias=True): +def _dnn_tree_combined_model_fn(features, + labels, + mode, + head, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + config=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=False, + use_core_versions=False): """DNN and GBDT combined model_fn. Args: @@ -106,6 +113,8 @@ def _dnn_tree_combined_model_fn( set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. Returns: A `ModelFnOps` object. @@ -135,11 +144,17 @@ def _dnn_tree_combined_model_fn( "input_from_feature_columns", values=tuple(six.itervalues(features)), partitioner=dnn_partitioner) as input_layer_scope: - input_layer = layers.input_from_feature_columns( - columns_to_tensors=features, - feature_columns=dnn_feature_columns, - weight_collections=[dnn_parent_scope], - scope=input_layer_scope) + if use_core_versions: + input_layer = feature_column_lib.input_layer( + features=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope]) + else: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) previous_layer = input_layer for layer_id, num_hidden_units in enumerate(dnn_hidden_units): with variable_scope.variable_scope( @@ -222,24 +237,51 @@ def _dnn_tree_combined_model_fn( del loss return control_flow_ops.no_op() - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_no_train_op_fn, - logits=tree_train_logits) - dnn_train_op = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_dnn_train_op_fn, - logits=dnn_logits).train_op - tree_train_op = head.create_model_fn_ops( - features=tree_features, - mode=mode, - labels=labels, - train_op_fn=_tree_train_op_fn, - logits=tree_train_logits).train_op + if use_core_versions: + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits) + dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + dnn_train_op).train_op + + tree_train_op = head.create_estimator_spec( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits) + tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops( + tree_train_op).train_op + + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op if tree_center_bias: num_trees += 1 @@ -277,7 +319,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedClassifier instance. Args: @@ -322,6 +365,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.multi_class_head( n_classes=n_classes, @@ -336,8 +381,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, @@ -366,7 +411,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedRegressor instance. Args: @@ -411,6 +457,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.regression_head( label_name=label_name, @@ -430,7 +478,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedRegressor, self).__init__( model_fn=_model_fn, model_dir=model_dir, @@ -460,7 +509,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, tree_feature_columns=None, - tree_center_bias=True): + tree_center_bias=False, + use_core_versions=False): """Initializes a DNNBoostedTreeCombinedEstimator instance. Args: @@ -500,6 +550,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): set to True, these features are in addition to dnn_feature_columns. tree_center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_versions: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ def _model_fn(features, labels, mode, config): return _dnn_tree_combined_model_fn( @@ -507,8 +559,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator): tree_learner_config, num_trees, tree_examples_per_layer, config, dnn_optimizer, dnn_activation_fn, dnn_dropout, dnn_input_layer_partitioner, dnn_input_layer_to_tree, - dnn_steps_to_train, - tree_feature_columns, tree_center_bias) + dnn_steps_to_train, tree_feature_columns, tree_center_bias, + use_core_versions) super(DNNBoostedTreeCombinedEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 83d58c561008e8a5a69eb503d1605bb9e940f281..f495edc62f0909880c170ccb4cf5d11e3f20f55c 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -19,15 +19,17 @@ from __future__ import division from __future__ import print_function import tempfile - from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest @@ -100,6 +102,35 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): classifier.fit(input_fn=_train_input_fn, steps=15) classifier.evaluate(input_fn=_eval_input_fn, steps=1) + def testFitAndEvaluateDontThrowExceptionWithCore(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + # Use core head + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + classifier = estimator.DNNBoostedTreeCombinedEstimator( + head=head_fn, + dnn_hidden_units=[1], + # Use core feature columns + dnn_feature_columns=[core_feature_column.numeric_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=True, + tree_feature_columns=[], + use_core_versions=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 01752416b347dd0a5e646283b6b5572592df4690..89d0d611d2905492cec09e033b8cbc238ec7fac6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -40,7 +40,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): label_keys=None, feature_engineering_fn=None, logits_modifier_function=None, - center_bias=True): + center_bias=True, + use_core_libs=False): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -63,7 +64,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. - + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. Raises: ValueError: If learner_config is not valid. """ @@ -81,7 +83,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): n_classes=n_classes, weight_column_name=weight_column_name, enable_centered_bias=False, - loss_fn=loss_fn) + loss_fn=loss_fn, + label_keys=label_keys) if learner_config.num_classes == 0: learner_config.num_classes = n_classes elif learner_config.num_classes != n_classes: @@ -98,6 +101,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'examples_per_layer': examples_per_layer, 'center_bias': center_bias, 'logits_modifier_function': logits_modifier_function, + 'use_core_libs': use_core_libs, }, model_dir=model_dir, config=config, @@ -119,7 +123,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): config=None, feature_engineering_fn=None, logits_modifier_function=None, - center_bias=True): + center_bias=True, + use_core_libs=False): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -144,6 +149,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ head = head_lib.regression_head( label_name=label_name, @@ -165,6 +172,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'examples_per_layer': examples_per_layer, 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, + 'use_core_libs': use_core_libs, }, model_dir=model_dir, config=config, @@ -188,7 +196,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): config=None, feature_engineering_fn=None, logits_modifier_function=None, - center_bias=True): + center_bias=True, + use_core_libs=False): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -209,6 +218,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. + use_core_libs: Whether feature columns and loss are from the core (as + opposed to contrib) version of tensorflow. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -221,6 +232,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'examples_per_layer': examples_per_layer, 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, + 'use_core_libs': use_core_libs, }, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0d58317bd59331cfcde0e12aeb3a3a03fc45d89b --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GBDT estimator.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +from tensorflow.contrib.boosted_trees.estimator_batch import estimator +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column +from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.feature_column import feature_column_lib as core_feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest + + +def _train_input_fn(): + features = {"x": constant_op.constant([[2.], [1.], [1.]])} + label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + return features, label + + +def _eval_input_fn(): + features = {"x": constant_op.constant([[1.], [2.], [2.]])} + label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) + return features, label + + +class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._export_dir_base = tempfile.mkdtemp() + "export/" + gfile.MkDir(self._export_dir_base) + + def testFitAndEvaluateDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + + def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + # Use core head + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + + model = estimator.GradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")], + use_core_libs=True) + + model.fit(input_fn=_train_input_fn, steps=15) + model.evaluate(input_fn=_eval_input_fn, steps=1) + model.export(self._export_dir_base) + + def testFitAndEvaluateDontThrowExceptionWithCoreForClassifier(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")], + use_core_libs=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + classifier.export(self._export_dir_base) + + def testFitAndEvaluateDontThrowExceptionWithCoreForRegressor(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + regressor = estimator.GradientBoostedDecisionTreeRegressor( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[core_feature_column.numeric_column("x")], + use_core_libs=True) + + regressor.fit(input_fn=_train_input_fn, steps=15) + regressor.evaluate(input_fn=_eval_input_fn, steps=1) + regressor.export(self._export_dir_base) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48a7f85eada8c72de83b814af2f00e97a62a073e --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for converting between core and contrib feature columns.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output + +_CORE_MODE_TO_CONTRIB_MODE_ = { + model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN, + model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER +} + + +def _core_mode_to_contrib_mode(mode): + return _CORE_MODE_TO_CONTRIB_MODE_[mode] + + +def _export_outputs_to_output_alternatives(export_outputs): + """Converts EstimatorSpec.export_outputs to output_alternatives. + + Args: + export_outputs: export_outputs created by create_estimator_spec. + Returns: + converted output_alternatives. + """ + output = dict() + if export_outputs is not None: + for key, value in export_outputs.items(): + if isinstance(value, export_output.ClassificationOutput): + exported_predictions = { + prediction_key.PredictionKey.SCORES: value.scores, + prediction_key.PredictionKey.CLASSES: value.classes + } + output[key] = (constants.ProblemType.CLASSIFICATION, + exported_predictions) + return output + return None + + +def estimator_spec_to_model_fn_ops(estimator_spec, export_alternatives=False): + if export_alternatives: + alternatives = _export_outputs_to_output_alternatives( + estimator_spec.export_outputs) + else: + alternatives = [] + + return model_fn.ModelFnOps( + mode=_core_mode_to_contrib_mode(estimator_spec.mode), + predictions=estimator_spec.predictions, + loss=estimator_spec.loss, + train_op=estimator_spec.train_op, + eval_metric_ops=estimator_spec.eval_metric_ops, + output_alternatives=alternatives) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index c6455a7ea3d18eb358edee034cee58b2bed21024..15ab6d814522ab1dee58dcd71246354fc4d8a483 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -20,6 +20,7 @@ from __future__ import print_function import copy +from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch @@ -60,6 +61,7 @@ def model_builder(features, labels, mode, params, config): feature_columns = params["feature_columns"] weight_column_name = params["weight_column_name"] num_trees = params["num_trees"] + use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] if features is None: raise ValueError("At least one feature must be specified.") @@ -93,7 +95,8 @@ def model_builder(features, labels, mode, params, config): learner_config=learner_config, feature_columns=feature_columns, logits_dimension=head.logits_dimension, - features=training_features) + features=training_features, + use_core_columns=use_core_libs) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -108,12 +111,22 @@ def model_builder(features, labels, mode, params, config): update_op = state_ops.assign_add(global_step, 1).op return update_op - model_fn_ops = head.create_model_fn_ops( - features=features, - mode=mode, - labels=labels, - train_op_fn=_train_op_fn, - logits=logits) + create_estimator_spec_op = getattr(head, "create_estimator_spec", None) + if use_core_libs and callable(create_estimator_spec_op): + model_fn_ops = head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops) + else: + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) if num_trees: if center_bias: num_trees += 1 diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 754b7bc3270d647fc381033b769eadd7b791771e..3bf33186ec13f5ff991db938d59849c0124a30a0 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -137,6 +137,61 @@ class TreeEnsembleDeserializeOp : public OpKernel { } }; +class TreeEnsembleUsedHandlersOp : public OpKernel { + public: + explicit TreeEnsembleUsedHandlersOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("num_all_handlers", &num_handlers_)); + } + + void Compute(OpKernelContext* context) override { + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; + + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); + + // Get the stamp token. + const Tensor* stamp_token_t; + OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); + int64 stamp_token = stamp_token_t->scalar()(); + + // Only the Chief should run this Op and it is guaranteed to be in + // a consistent state so the stamps must always match. + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); + + Tensor* output_used_handlers_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output("used_handlers_mask", {num_handlers_}, + &output_used_handlers_t)); + auto output_used_handlers = output_used_handlers_t->vec(); + + Tensor* output_num_used_handlers_t = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("num_used_handlers", {}, + &output_num_used_handlers_t)); + int handler_idx = 0; + std::vector used_handlers = ensemble_resource->GetUsedHandlers(); + output_num_used_handlers_t->scalar()() = used_handlers.size(); + for (int64 i = 0; i < num_handlers_; ++i) { + if (handler_idx >= used_handlers.size() || + used_handlers[handler_idx] > i) { + output_used_handlers(i) = false; + } else { + OP_REQUIRES(context, used_handlers[handler_idx] == i, + errors::InvalidArgument("Handler IDs should be sorted.")); + ++handler_idx; + output_used_handlers(i) = true; + } + } + } + + private: + int64 num_handlers_; +}; + REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeEnsembleResource); REGISTER_KERNEL_BUILDER( @@ -155,5 +210,7 @@ REGISTER_KERNEL_BUILDER(Name("TreeEnsembleSerialize").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("TreeEnsembleDeserialize").Device(DEVICE_CPU), TreeEnsembleDeserializeOp); +REGISTER_KERNEL_BUILDER(Name("TreeEnsembleUsedHandlers").Device(DEVICE_CPU), + TreeEnsembleUsedHandlersOp); } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 0f4c2298f56be48bb32f52d5d44cff8afe284f1e..0b28f81e7ca9a1228adc5bde19c429265e0aa9b8 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -253,7 +253,7 @@ class CreateQuantileAccumulatorOp : public OpKernel { private: float epsilon_; int32 num_quantiles_; - // An upperbound on the number of enteries that the summaries might have + // An upper bound on the number of entries that the summaries might have // for a feature. int64 max_elements_; bool generate_quantiles_; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 7f8dea1d3c2a04b725843f6e2932a0cdfbc7733c..1bfeed306641111718984b2097512e5ec3fa8630 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -361,27 +361,10 @@ class GrowTreeEnsembleOp : public OpKernel { // Increment attempt stats. ensemble_resource->IncrementAttempts(); - // In case we want to do feature selection and we have reached the limit, - // build a list of handlers used so far to avoid adding new features. - std::vector allowed_handlers; - if (learner_config_.constraints().max_number_of_unique_feature_columns() > - 0) { - allowed_handlers = ensemble_resource->GetUsedHandlers(); - // TODO(soroush): We can disable handlers that are not going to be used to - // avoid unnecessary computations. - if (allowed_handlers.size() < - learner_config_.constraints() - .max_number_of_unique_feature_columns()) { - // We have not reached the limit yet. Empty the list of allow features - // which means we can keep adding new features. - allowed_handlers.clear(); - } - } - // Find best splits for each active partition. std::map best_splits; - FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list, - gains_list, splits_list, &best_splits); + FindBestSplitsPerPartition(context, partition_ids_list, gains_list, + splits_list, &best_splits); // No-op if no new splits can be considered. if (best_splits.empty()) { @@ -422,19 +405,12 @@ class GrowTreeEnsembleOp : public OpKernel { // and finds the best split for each partition. void FindBestSplitsPerPartition( OpKernelContext* const context, - const std::vector& allowed_handlers, // Empty means all handlers. const OpInputList& partition_ids_list, const OpInputList& gains_list, const OpInputList& splits_list, std::map* best_splits) { // Find best split per partition going through every feature candidate. // TODO(salehay): Is this worth parallelizing? for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { - if (!allowed_handlers.empty()) { - if (!std::binary_search(allowed_handlers.begin(), - allowed_handlers.end(), handler_id)) { - continue; - } - } const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); const auto& splits = splits_list[handler_id].vec(); diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 131bd48562a55a08981ac73277e93024db0d85d3..3028c2281705bd7e34b212332160d25386559d4e 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -15,17 +15,6 @@ load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # Utils cc_library( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index cd925f6b65e569538212e9c26aef0abc8482960b..794ba2bcb0aafa26c5e1c90fcd66caf9dd5bf7d5 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -137,7 +137,7 @@ struct NodeStats { Eigen::MatrixXf hessian = TensorToEigenMatrix(grad_stats.second.t, grad_dim, grad_dim); // I is an identity matrix. - // The gain in general form is -g^T (H+l2 I)^-1 g. + // The gain in general form is g^T (H+l2 I)^-1 g. // The node weights are -(H+l2 I)^-1 g. Eigen::MatrixXf identity; identity.setIdentity(grad_dim, grad_dim); @@ -240,7 +240,7 @@ struct NodeStats { // given regularized Hessian and gradient vector g. void CalculateWeightAndGain(const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g) { - // The gain in general form is -g^T (Hessian_and_regularization)^-1 g. + // The gain in general form is g^T (Hessian_and_regularization)^-1 g. // The node weights are -(Hessian_and_regularization)^-1 g. Eigen::VectorXf weight; // If we want to calculate x = K^-1 v, instead of explicitly calculating diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index cf4f9a097a3368465fd4d9afb981bbaa68b4df49..35b059f3496dbc8fb2b3d4fe6ec6b55a9d73dd0c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -54,7 +54,7 @@ Status BatchFeatures::Initialize( TF_CHECK_AND_RETURN_IF_ERROR( dense_float_feature.dim_size(1) == 1, errors::InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = ", + "Dense float features may not be multivalent: dim_size(1) = ", dense_float_feature.dim_size(1))); dense_float_feature_columns_.emplace_back(dense_float_feature); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index da5e7448519cb7f4092f7bbbe1b526271008ec22..a3b1b013e3a40116f74d6ed2df78d87ed3a11ac7 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -48,9 +48,9 @@ class BatchFeatures { Status GetFeatureColumnSizes(int64* const num_dense_float_features, int64* const num_sparse_float_features, int64* const num_sparse_int_features) const { - QCHECK_NE(num_dense_float_features, nullptr); - QCHECK_NE(num_sparse_float_features, nullptr); - QCHECK_NE(num_sparse_int_features, nullptr); + QCHECK_NE(num_dense_float_features, static_cast(nullptr)); + QCHECK_NE(num_sparse_float_features, static_cast(nullptr)); + QCHECK_NE(num_sparse_int_features, static_cast(nullptr)); *num_dense_float_features = dense_float_feature_columns_.size(); *num_sparse_float_features = sparse_float_feature_columns_.size(); *num_sparse_int_features = sparse_int_feature_columns_.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 609519e8b1153a27d987c5f9ca9bfcc9ee6717d6..cfe9101e7435cd798569f3e52a87fc8ed7b6a239 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -59,7 +59,7 @@ TEST_F(BatchFeaturesTest, DenseFloatFeatures_Multivalent) { BatchFeatures batch_features(1); auto dense_vec = AsTensor({3.0f, 7.0f}, {1, 2}); auto expected_error = InvalidArgument( - "Dense float features may not be multi-valent: dim_size(1) = 2"); + "Dense float features may not be multivalent: dim_size(1) = 2"); EXPECT_EQ(expected_error, batch_features.Initialize({dense_vec}, {}, {}, {}, {}, {}, {})); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index db34db998a7442c69f2ab468f4557d991429f4ee..ce67db797ded54f5023eaa89369d4781aad31a7c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -54,7 +54,7 @@ Status DropoutUtils::DropOutTrees( if (probability_of_skipping_dropout < 0 || probability_of_skipping_dropout > 1) { return errors::InvalidArgument( - "Probability of skiping dropout must be in [0,1] range"); + "Probability of skipping dropout must be in [0,1] range"); } const auto num_trees = weights.size(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h index 928bfbfe5c9394ab4083aabced4c8e1149bb10aa..77c16da5410fe65b20839c7b6bc677067d7ff297 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h @@ -66,7 +66,7 @@ class DropoutUtils { // Current weights and num_updates will be updated as a result of this // func std::vector* current_weights, - // How many weight assignements have been done for each tree already. + // How many weight assignments have been done for each tree already. std::vector* num_updates); }; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc index 0138aae3dbd3773241cb6644db625b99f9bf1372..cc7604745e6bb90837eeca1123faa88dc914e4fc 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc @@ -34,7 +34,7 @@ TEST_F(SparseColumnIterableTest, Empty) { } TEST_F(SparseColumnIterableTest, Iterate) { - // 8 examples having 7 sparse features with the 3rd and 7th multi-valent. + // 8 examples having 7 sparse features with the 3rd and 7th multivalent. // This can be visualized like the following: // Instance | Sparse | // 0 | x | diff --git a/tensorflow/contrib/boosted_trees/ops/model_ops.cc b/tensorflow/contrib/boosted_trees/ops/model_ops.cc index 0786c4166410720e8d4d70960e5747ff111076d8..9d6343c7e80f369bf6a5465821c5f4bacb984cd0 100644 --- a/tensorflow/contrib/boosted_trees/ops/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/model_ops.cc @@ -110,5 +110,32 @@ stamp_token: Token to use as the new value of the resource stamp. tree_ensemble_config: Serialized proto of the ensemble. )doc"); +REGISTER_OP("TreeEnsembleUsedHandlers") + .Attr("num_all_handlers: int >= 0") + .Input("tree_ensemble_handle: resource") + .Input("stamp_token: int64") + .Output("num_used_handlers: int64") + .Output("used_handlers_mask: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + c->set_output(0, c->Scalar()); + int num_all_handlers; + c->GetAttr("num_all_handlers", &num_all_handlers).IgnoreError(); + c->set_output(1, {c->Vector(num_all_handlers)}); + + return Status::OK(); + }) + .Doc(R"doc( +Returns the mask of used handlers along with the number of non-zero elements in +this mask. Used in feature selection. + +tree_ensemble_handle: Handle to the tree ensemble. +stamp_token: Token to use as the new value of the resource stamp. +num_used_handlers: number of feature column handlers used in the model. +used_handlers_mask: A boolean vector of showing which handlers are used in the + model. +)doc"); + } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index ae99d53a2cf805d70d60746cd44f73f7fd9dc6e2..6aa52463987b55a54b7308765920cbe94c15b8d1 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -272,6 +272,20 @@ REGISTER_OP("Quantiles") .Input("sparse_indices: num_sparse_features * int64") .Output("dense_quantiles: num_dense_features * int32") .Output("sparse_quantiles: num_sparse_features * int32") + .SetShapeFn([](InferenceContext* c) { + int num_dense_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_dense_features", &num_dense_features)); + int num_sparse_features; + TF_RETURN_IF_ERROR( + c->GetAttr("num_sparse_features", &num_sparse_features)); + // Set output shapes (dense_quantiles and sparse_quantiles) by the + // relevant inputs (dense_values and sparse_values). Note that the output + // has an additional dimension for dimension_ids. + for (int i = 0; i < num_dense_features + num_sparse_features; ++i) { + c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 2})); + } + return Status::OK(); + }) .Doc(R"doc( Computes quantile for each a given list of dense and sparse feature values using the given buckets. diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index 9a61e163eb5ff51dc75de4e40e0f43b090d03c0c..b07f0a4314246eea63764bb6d5e166dd720644fb 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -4,17 +4,6 @@ exports_files(["LICENSE"]) load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "learner_proto", srcs = [ diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 4407c4d981785a279b6296f4726a221cacb4c5b1..81411aa84ae848cfaa1392e82a1e38c3df19cdb6 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,7 +53,7 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the dimensiong + // If feature column is multivalent, this holds the index of the dimension // for the split. Defaults to 0. int32 dimension_id = 5; float threshold = 2; diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 27c288bbf78b3b593d0807e92ac7fd9afc4d2725..63b9c5fddf0d9967d53077608664b59d9ae00481 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -310,6 +310,22 @@ class ModelOpsTest(test_util.TensorFlowTestCase): # The third tree was added after the save. self.assertAllClose(result.eval(), [[-1.1], [-1.1]]) + def testUsedHandlers(self): + with self.test_session(): + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree_ensemble_config.growing_metadata.used_handler_ids.append(1) + tree_ensemble_config.growing_metadata.used_handler_ids.append(5) + stamp_token = 3 + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=stamp_token, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="create_tree") + resources.initialize_resources(resources.shared_resources()).run() + result = model_ops.tree_ensemble_used_handlers( + tree_ensemble_handle, stamp_token, num_all_handlers=6) + self.assertAllEqual([0, 1, 0, 0, 0, 1], result.used_handlers_mask.eval()) + self.assertEqual(2, result.num_used_handlers.eval()) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf..cf55759aaabfb265466f4bbf8b2806d4347ca0b1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -120,8 +120,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): """Sets up the prediction tests. Create a batch of two examples having one dense float, two sparse float - single valued, one sparse float multidimensionl and one sparse int features. - The data looks like the following: + single valued, one sparse float multidimensional and one sparse int + features. The data looks like the following: | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM | 0 | 7 | -3 | | 9,1 | __, 5.0 | 1 | -2 | | 4 | | 3, ___ @@ -810,7 +810,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # building. This tree should never be dropped. num_trees = 10 with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 10 trees with some weights. for i in range(0, num_trees): @@ -951,7 +951,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testDropOutZeroProb(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Add 1000 trees with some weights. for i in range(0, 999): @@ -994,7 +994,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def testAveragingAllTrees(self): with self.test_session(): - # Empty tree ensenble. + # Empty tree ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() adjusted_tree_ensemble_config = ( tree_config_pb2.DecisionTreeEnsembleConfig()) diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 81f58de28cbe98bb996c6665114eeb0030ee52f9..074623699d9d82f999c9cbc483ddcd8a959f4bad 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -482,7 +482,7 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): """Sets up the quantile op tests. Create a batch of 4 examples having 2 dense and 4 sparse features. - Forth sparse feature is multivalent (3 dimensional) + Fourth sparse feature is multivalent (3 dimensional) The data looks like this | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 8ca1aabacaf53b66aaba184962922294427d6803..3e524efbeac74ff754d63cae92b3e194411cb2de 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -1588,7 +1588,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual( 2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates) - def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self): + def testGrowExistingEnsembleTreeWithFeatureSelectionUsedHandlers(self): """Test growing a tree with feature selection.""" with self.test_session() as session: # Create existing ensemble with one root split and one bias tree. @@ -1649,7 +1649,6 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): num_trees_attempted: 2 num_layers_attempted: 2 used_handler_ids: 2 - used_handler_ids: 5 } """, tree_ensemble_config) tree_ensemble_handle = model_ops.tree_ensemble_variable( @@ -1668,183 +1667,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - # There are 2 handler_ids in used_handler_ids already but one of them - # is handler 2, so we can still grow trees. - learner_config.constraints.max_number_of_unique_feature_columns = 2 - learner_config = learner_config.SerializeToString() - # Prepare handler inputs. - handler1_partitions = np.array([0], dtype=np.int32) - handler1_gains = np.array([7.62], dtype=np.float32) - handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] - handler2_partitions = np.array([0], dtype=np.int32) - handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] - handler3_partitions = np.array([0], dtype=np.int32) - handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] - - # Grow tree ensemble. - grow_op = training_ops.grow_tree_ensemble( - tree_ensemble_handle, - stamp_token=0, - next_stamp_token=1, - learning_rate=1, - partition_ids=[ - handler1_partitions, handler2_partitions, handler3_partitions - ], - gains=[handler1_gains, handler2_gains, handler3_gains], - splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, - dropout_seed=123, - center_bias=True) - session.run(grow_op) - - # Expect a new tree to be added with the split from handler 1. - _, serialized = session.run( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)) - tree_ensemble_config.ParseFromString(serialized) - self.assertEqual(3, len(tree_ensemble_config.trees)) - self.assertEqual( - 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) - - def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self): - """Test growing a tree with feature selection with empty ensemble.""" - with self.test_session() as session: - # Create existing ensemble with one root split and one bias tree. - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble_config.SerializeToString(), - name="tree_ensemble") - resources.initialize_resources(resources.shared_resources()).run() - - # Prepare learner config. - learner_config = _gen_learner_config( - num_classes=2, - l1_reg=0, - l2_reg=0, - tree_complexity=0, - max_depth=1, - min_node_weight=0, - pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - learner_config.constraints.max_number_of_unique_feature_columns = 2 - learner_config = learner_config.SerializeToString() - # Prepare handler inputs. - handler1_partitions = np.array([0], dtype=np.int32) - handler1_gains = np.array([7.62], dtype=np.float32) - handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] - handler2_partitions = np.array([0], dtype=np.int32) - handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] - handler3_partitions = np.array([0], dtype=np.int32) - handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] - - # Grow tree ensemble. - grow_op = training_ops.grow_tree_ensemble( - tree_ensemble_handle, - stamp_token=0, - next_stamp_token=1, - learning_rate=1, - partition_ids=[ - handler1_partitions, handler2_partitions, handler3_partitions - ], - gains=[handler1_gains, handler2_gains, handler3_gains], - splits=[handler1_split, handler2_split, handler3_split], - learner_config=learner_config, - dropout_seed=123, - center_bias=True) - session.run(grow_op) - - _, serialized = session.run( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)) - tree_ensemble_config.ParseFromString(serialized) - self.assertEqual(1, len(tree_ensemble_config.trees)) - self.assertEqual( - 1, len(tree_ensemble_config.growing_metadata.used_handler_ids)) - - def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self): - """Test growing a tree with feature selection with empty ensemble.""" - with self.test_session() as session: - # Create existing ensemble with one root split and one bias tree. - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" - trees { - nodes { - leaf { - vector { - value: -0.32 - value: 0.28 - } - } - } - } - trees { - nodes { - categorical_id_binary_split { - feature_column: 3 - feature_id: 7 - left_id: 1 - right_id: 2 - } - node_metadata { - gain: 1.3 - } - } - nodes { - leaf { - sparse_vector { - index: 0 - value: 2.3 - } - } - } - nodes { - leaf { - sparse_vector { - index: 0 - value: -0.9 - } - } - } - } - tree_weights: 0.7 - tree_weights: 1 - tree_metadata { - num_tree_weight_updates: 1 - num_layers_grown: 1 - is_finalized: true - } - tree_metadata { - num_tree_weight_updates: 5 - num_layers_grown: 1 - is_finalized: true - } - growing_metadata { - num_trees_attempted: 2 - num_layers_attempted: 2 - used_handler_ids: 4 - used_handler_ids: 5 - } - """, tree_ensemble_config) - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble_config.SerializeToString(), - name="tree_ensemble") - resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = _gen_learner_config( - num_classes=2, - l1_reg=0, - l2_reg=0, - tree_complexity=0, - max_depth=1, - min_node_weight=0, - pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) - learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config.constraints.max_number_of_unique_feature_columns = 3 learner_config = learner_config.SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1876,12 +1700,10 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): _, serialized = session.run( model_ops.tree_ensemble_serialize(tree_ensemble_handle)) tree_ensemble_config.ParseFromString(serialized) - # We can't grow a tree since we have reached the limit of 2 unique - # features [4, 5] and the only available splits are from - # handlers [0, 1, 2]. - self.assertEqual(2, len(tree_ensemble_config.trees)) - self.assertEqual( - 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + self.assertEqual(3, len(tree_ensemble_config.trees)) + # 2 was already used. handler 0 is being added in this tree. + self.assertAllEqual( + [0, 2], tree_ensemble_config.growing_metadata.used_handler_ids) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py index 7a5f509047d46549ba81039a23d29ec987ca7920..25b2c9e2fd72bd018717e8a87fce726f26bad968 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/model_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/model_ops.py @@ -25,6 +25,7 @@ from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensem from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_serialize # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_stamp_token +from tensorflow.contrib.boosted_trees.python.ops.gen_model_ops import tree_ensemble_used_handlers # pylint: enable=unused-import from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 97d57e8b23608d4c3a8719426a75056fc6417d1d..1b184d296b329cee481db67992e77d1e33e18035 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -184,7 +184,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): """Finalizes quantile summary stream and resets it for next iteration. Args: - stamp_token: Exepcted current token. + stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: A list of quantiles or approximate boundaries. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index f0b66dcbbe1c5167b9993e66b30b1dc8a839c380..4bde7f3e33d6f8b295cd35cb32bbbccecf8a2b87 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -23,7 +23,6 @@ import copy from tensorflow.contrib import learn from tensorflow.contrib import stateless - from tensorflow.contrib.boosted_trees.lib.learner.batch import categorical_split_handler from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 @@ -57,6 +56,8 @@ PREDICTIONS = "predictions" PARTITION_IDS = "partition_ids" NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" +NUM_USED_HANDLERS = "num_used_handlers" +USED_HANDLERS_MASK = "used_handlers_mask" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -70,7 +71,8 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): +def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, + used_handlers): """Returns predictions for the given logits and n_classes. Args: @@ -79,6 +81,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): that contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. + used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a + boolean mask.. Returns: A dict of predictions. @@ -89,6 +93,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): result[PARTITION_IDS] = partition_ids result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees + result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers + result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask return result @@ -134,7 +140,7 @@ class _OpRoundRobinStrategy(object): return task -def extract_features(features, feature_columns): +def extract_features(features, feature_columns, use_core_columns): """Extracts columns from a dictionary of features. Args: @@ -167,7 +173,11 @@ def extract_features(features, feature_columns): transformed_features = collections.OrderedDict() for fc in feature_columns: # pylint: disable=protected-access - if isinstance(fc, feature_column_lib._EmbeddingColumn): + if use_core_columns: + # pylint: disable=protected-access + tensor = fc_core._transform_features(features, [fc])[fc] + transformed_features[fc.name] = tensor + elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( features, [fc], @@ -258,7 +268,8 @@ class GradientBoostedDecisionTreeModel(object): learner_config, features, logits_dimension, - feature_columns=None): + feature_columns=None, + use_core_columns=False): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -331,8 +342,9 @@ class GradientBoostedDecisionTreeModel(object): if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, - sparse_float_shapes, sparse_int_indices, sparse_int_values, - sparse_int_shapes) = extract_features(features, self._feature_columns) + sparse_float_shapes, sparse_int_indices, + sparse_int_values, sparse_int_shapes) = extract_features( + features, self._feature_columns, use_core_columns) logging.info("Active Feature Columns: " + str(fc_names)) self._fc_names = fc_names self._dense_floats = dense_floats @@ -361,6 +373,13 @@ class GradientBoostedDecisionTreeModel(object): """ ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) + num_handlers = ( + len(self._dense_floats) + len(self._sparse_float_shapes) + + len(self._sparse_int_shapes)) + # Used during feature selection. + used_handlers = model_ops.tree_ensemble_used_handlers( + ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) + # We don't need dropout info - we can always restore it based on the # seed. apply_dropout, seed = _dropout_params(mode, ensemble_stats) @@ -395,7 +414,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats) + ensemble_stats, used_handlers) def predict(self, mode): """Returns predictions given the features and mode. @@ -710,12 +729,28 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack(active_handlers_current_layer, - array_ops.ones( - [len(handlers)], dtype=dtypes.bool)) + active_handlers = array_ops.stack([ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) + if self._learner_config.constraints.max_number_of_unique_feature_columns: + target = ( + self._learner_config.constraints.max_number_of_unique_feature_columns) + + def _feature_selection_active_handlers(): + # The active list for current and the next iteration. + used_handlers = array_ops.reshape(predictions_dict[USED_HANDLERS_MASK], + [-1, 1]) + used_handlers = array_ops.concat([used_handlers, used_handlers], axis=1) + return math_ops.logical_and(used_handlers, active_handlers) + + active_handlers = ( + control_flow_ops.cond(predictions_dict[NUM_USED_HANDLERS] >= target, + _feature_selection_active_handlers, + lambda: active_handlers)) + # Prepare empty gradients and hessians when handlers are not ready. empty_hess_shape = [1] + hessian_shape.as_list() empty_grad_shape = [1] + gradient_shape.as_list() diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index dba51d4f527792d2a8dedc693f74c07119fd231d..f9c22283b7f5136777bfa60a12c94974adfbd245 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -27,9 +27,11 @@ from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.boosted_trees.python.utils import losses +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn + from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -43,10 +45,42 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keep_dims=True) + math_ops.square(predictions - label), 1, keepdims=True) return loss +def _append_to_leaf(leaf, c_id, w): + """Helper method for building tree leaves. + + Appends weight contributions for the given class index to a leaf node. + + Args: + leaf: leaf node to append to. + c_id: class Id for the weight update. + w: weight contribution value. + """ + leaf.sparse_vector.index.append(c_id) + leaf.sparse_vector.value.append(w) + + +def _set_float_split(split, feat_col, thresh, l_id, r_id): + """Helper method for building tree float splits. + + Sets split feature column, threshold and children. + + Args: + split: split node to update. + feat_col: feature column for the split. + thresh: threshold to split on forming rule x <= thresh. + l_id: left child Id. + r_id: right child Id. + """ + split.feature_column = feat_col + split.threshold = thresh + split.left_id = l_id + split.right_id = r_id + + class GbdtTest(test_util.TensorFlowTestCase): def setUp(self): @@ -67,7 +101,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, - sparse_int_shapes) = (gbdt_batch.extract_features(features, None)) + sparse_int_shapes) = ( + gbdt_batch.extract_features(features, None, use_core_columns=False)) self.assertEqual(len(fc_names), 3) self.assertAllEqual(fc_names, ["dense_float", "sparse_float", "sparse_int"]) @@ -116,8 +151,9 @@ class GbdtTest(test_util.TensorFlowTestCase): "sparse_categorical", hash_bucket_size=1000000)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, - sparse_int_shapes) = (gbdt_batch.extract_features( - features, feature_columns)) + sparse_int_shapes) = ( + gbdt_batch.extract_features( + features, feature_columns, use_core_columns=False)) self.assertEqual(len(fc_names), 3) self.assertAllEqual(fc_names, ["dense_float", "sparse_float", "sparse_categorical"]) @@ -142,6 +178,41 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllEqual(sparse_int_shapes[0].eval(), features["sparse_categorical"].dense_shape.eval()) + def testExtractFeaturesFromCoreFeatureColumns(self): + """Tests feature extraction when using core columns.""" + with self.test_session(): + features = {} + # Sparse float column does not exist in core, so only dense numeric and + # categorical. + features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32) + features["sparse_categorical"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) + + feature_columns = set() + feature_columns.add(core_feature_column.numeric_column("dense_float")) + feature_columns.add( + core_feature_column.categorical_column_with_hash_bucket( + "sparse_categorical", hash_bucket_size=1000000)) + (fc_names, dense_floats, _, _, _, sparse_int_indices, sparse_int_values, + sparse_int_shapes) = ( + gbdt_batch.extract_features( + features, feature_columns, use_core_columns=True)) + self.assertEqual(len(fc_names), 2) + self.assertAllEqual(fc_names, ["dense_float", "sparse_categorical"]) + self.assertEqual(len(dense_floats), 1) + self.assertEqual(len(sparse_int_indices), 1) + self.assertEqual(len(sparse_int_values), 1) + self.assertEqual(len(sparse_int_shapes), 1) + self.assertAllEqual(dense_floats[0].eval(), + features["dense_float"].eval()) + self.assertAllEqual(sparse_int_indices[0].eval(), + features["sparse_categorical"].indices.eval()) + self.assertAllEqual(sparse_int_values[0].eval(), [397263, 397263]) + self.assertAllEqual(sparse_int_shapes[0].eval(), + features["sparse_categorical"].dense_shape.eval()) + def testTrainFnChiefNoBiasCentering(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -917,6 +988,350 @@ class GbdtTest(test_util.TensorFlowTestCase): output.trees[0].nodes[2].leaf.sparse_vector.value[0], atol=1e-4, rtol=1e-4) + def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + # Feature 1 is predictive but it won't be used because we have reached the + # limit of num_used_handlers >= max_number_of_unique_feature_columns + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([True, False], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + # On second run, expect a trivial split to be chosen to basically + # predict the average. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + dense_float_binary_split { + feature_column: 0 + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: -0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + + def testTrainFnChiefFeatureSelectionWithGoodSplits(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + # Feature 1 is predictive and is in our selected features so it will be + # used even when we're at the limit. + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([False, True], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + dense_float_binary_split { + feature_column: 1 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0.5 + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + + def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self): + """Tests the train function running on chief with feature selection.""" + with self.test_session() as sess: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree = tree_ensemble_config.trees.add() + + _set_float_split(tree.nodes.add() + .sparse_float_binary_split_default_right.split, 2, 4.0, + 1, 2) + _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) + _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) + tree_ensemble_config.tree_weights.append(1.0) + metadata = tree_ensemble_config.tree_metadata.add() + metadata.is_finalized = False + metadata.num_layers_grown = 1 + tree_ensemble_config = tree_ensemble_config.SerializeToString() + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config=tree_ensemble_config, + name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.max_number_of_unique_feature_columns = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + # Both features will be disabled since the feature selection limit is + # already reached. + features["dense_float_0"] = array_ops.ones([4, 1], dtypes.float32) + features["dense_float_1"] = array_ops.constant([0, 0, 1, 1], + dtypes.float32) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": + predictions, + "predictions_no_dropout": + predictions, + "partition_ids": + partition_ids, + "ensemble_stamp": + ensemble_stamp, + "num_trees": + 12, + # We have somehow reached our limit 1. Both of the handlers will be + # disabled. + "num_used_handlers": + array_ops.constant(1, dtype=dtypes.int64), + "used_handlers_mask": + array_ops.constant([False, False], dtype=dtypes.bool), + } + + labels = array_ops.constant([0, 0, 1, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertEquals(output.growing_metadata.num_layers_attempted, 1) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + # Make sure the trees are not modified, but the num_layers_attempted is + # incremented so that eventually the training stops. + self.assertEquals(len(output.trees), 1) + self.assertEquals(len(output.trees[0].nodes), 3) + + self.assertEquals(output.growing_metadata.num_layers_attempted, 2) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/resources/BUILD b/tensorflow/contrib/boosted_trees/resources/BUILD index 9fc101612f1e2a6bf6c5d86ea8c7199936dbb069..c0651868453d40d57e842862855f89e6845c507f 100644 --- a/tensorflow/contrib/boosted_trees/resources/BUILD +++ b/tensorflow/contrib/boosted_trees/resources/BUILD @@ -9,17 +9,6 @@ package( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - cc_library( name = "stamped_resource", hdrs = ["stamped_resource.h"], diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 3ebf28ea442edf87815c39971ae9e01a2a8aae9a..94aeb2c7bb48c6eddb6c7894f8bf6f1567470113 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -126,7 +126,8 @@ class DecisionTreeEnsembleResource : public StampedResource { return; } used_ids->Add(handler_id); - std::rotate(first, used_ids->end() - 1, used_ids->end()); + // Keep the list of used handlers sorted. + std::sort(used_ids->begin(), used_ids->end()); } std::vector GetUsedHandlers() const { diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index fe8bd072afd43a64fa62a65bd8900b5a98dbe761..f3a75e8688ece19a6e6fd53ee9faf7f4144d76cf 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -14,18 +14,6 @@ load( "tf_py_test", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_gen_op_libs( op_lib_names = ["bigquery_reader_ops"], deps = [ diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 56f930a9a8d32c5c3a025163ef56c9562f17d864..ff46f0daa80a70badedf73e15bfaf4dca85fdd89 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -20,20 +20,6 @@ load( "tf_proto_library", ) -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_kernel_library( name = "bigquery_reader_ops", srcs = ["bigquery_reader_ops.cc"], @@ -73,6 +59,7 @@ tf_cc_test( ], deps = [ ":bigquery_table_accessor", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index e9b79a066def566096d6c3f3745974423e3371d1..7416eb19d3324fad84876cde5353bc25bac8f648 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" #include "tensorflow/core/platform/test.h" @@ -28,8 +29,8 @@ constexpr char kTestProject[] = "test-project"; constexpr char kTestDataset[] = "test-dataset"; constexpr char kTestTable[] = "test-table"; -bool HasSubstr(const string& base, const string& substr) { - bool ok = StringPiece(base).contains(substr); +bool HasSubstr(StringPiece base, StringPiece substr) { + bool ok = str_util::StrContains(base, substr); EXPECT_TRUE(ok) << base << ", expected substring " << substr; return ok; } diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 80e18a43a71cc9d6c9e2ccf5836e50c6427a30f6..c239e6f8f960910cee14e1df7c4678c643496f54 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -10,19 +10,6 @@ package( licenses(["notice"]) # Apache 2.0 -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - py_library( name = "cluster_resolver_pip", srcs = [ @@ -30,6 +17,7 @@ py_library( "python/training/__init__.py", ], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":cluster_resolver_py", ":gce_cluster_resolver_py", @@ -109,5 +97,6 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training", ], + grpc_enabled = True, main = "python/training/tpu_cluster_resolver_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index b04822fa9d66465e34a545d3b00c399bbb196514..1c480b25134b1e54200e0ddb780bd7bb0f122341 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -53,11 +53,16 @@ class ClusterResolver(object): raise NotImplementedError( 'cluster_spec is not implemented for {}.'.format(self)) + @abc.abstractmethod + def master(self): + """...""" + raise NotImplementedError('master is not implemented for {}.'.format(self)) + class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec): + def __init__(self, cluster_spec, master=''): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() @@ -65,10 +70,18 @@ class SimpleClusterResolver(ClusterResolver): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec + if not isinstance(master, str): + raise TypeError('master must be a string.') + self._master = master + def cluster_spec(self): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec + def master(self): + """Returns the master address to use when creating a session.""" + return self._master + class UnionClusterResolver(ClusterResolver): """Performs a union on underlying ClusterResolvers. @@ -87,9 +100,13 @@ class UnionClusterResolver(ClusterResolver): Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. + ValueError: If there are no arguments passed. """ super(UnionClusterResolver, self).__init__() + if not args: + raise ValueError('At least one ClusterResolver is required.') + for cluster_resolver in args: if not isinstance(cluster_resolver, ClusterResolver): raise TypeError('All arguments must be a sub-class of ' @@ -169,3 +186,7 @@ class UnionClusterResolver(ClusterResolver): merged_cluster[job_name].update(task_dict) return ClusterSpec(merged_cluster) + + def master(self): + """master returns the master address from the first cluster resolver.""" + return self._cluster_resolvers[0].master() diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index dbfb77723cdaab66e29bb41b764593bb5fd61b35..d9c97d53eb3663f6ab2f7b40395592dc7638b896 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -234,5 +234,7 @@ class UnionClusterResolverTest(test.TestCase): self._verifyClusterSpecEquality(cluster_spec, expected_proto) +# TODO(saeta): Include tests for master resolution + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index d6f2eced93ba4fda5ac27f9412b6f729981f4f40..3f5824128948453634bc5e5a7d6fdeedae60f5bd 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -134,3 +134,6 @@ class GceClusterResolver(ClusterResolver): worker_list.sort() return ClusterSpec({self._job_name: worker_list}) + + def master(self): + return '' diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index a6a6e642e4e4c721b94821a70d55d6fe931347d6..5a2771229d9ffe2b5b389d1077fe02a230e9a4c0 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -18,12 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from six.moves.urllib.request import Request from six.moves.urllib.request import urlopen from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat _GOOGLE_API_CLIENT_INSTALLED = True try: @@ -33,6 +35,9 @@ except ImportError: _GOOGLE_API_CLIENT_INSTALLED = False +_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' + + class TPUClusterResolver(ClusterResolver): """Cluster Resolver for Google Cloud TPUs. @@ -46,13 +51,32 @@ class TPUClusterResolver(ClusterResolver): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) resp = urlopen(req) - return resp.read() + return compat.as_bytes(resp.read()) + + def _shouldResolve(self): + if (self._tpu == compat.as_bytes('') or + self._tpu == compat.as_bytes('local') or + self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('grpc://'))): + return False + return True + + @staticmethod + def _inGke(): + """When running in GKE, the environment variable will be set.""" + return _GKE_ENV_VARIABLE in os.environ + + @staticmethod + def _gkeMaster(): + return os.environ[_GKE_ENV_VARIABLE].split(',')[0] def __init__(self, - tpu_names, + tpu=None, zone=None, project=None, - job_name='tpu_worker', + job_name='worker', + coordinator_name=None, + coordinator_address=None, credentials='default', service=None): """Creates a new TPUClusterResolver object. @@ -61,7 +85,11 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - tpu_names: A list of names of the target Cloud TPUs. + tpu: Either a string, or a list of strings corresponding to the TPUs to + use. If the single string is the empty string, the string 'local', or a + string that begins with 'grpc://' or '/bns', then it is assumed to not + correspond with a Cloud TPU and will instead be passed as the session + master and no ClusterSpec propagation will be done. zone: Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service. @@ -69,6 +97,12 @@ class TPUClusterResolver(ClusterResolver): empty, we will try to discover the project name of the GCE VM from the GCE metadata service. job_name: Name of the TensorFlow job the TPUs belong to. + coordinator_name: The name to use for the coordinator. Set to None if the + coordinator should not be included in the computed ClusterSpec. + coordinator_address: The address of the coordinator (typically an ip:port + pair). If set to None, a TF server will be started. If coordinator_name + is None, a TF server will not be started even if coordinator_address is + None. credentials: GCE Credentials. If None, then we use default credentials from the oauth2client service: The GCE API object returned by the googleapiclient.discovery @@ -77,29 +111,48 @@ class TPUClusterResolver(ClusterResolver): Raises: ImportError: If the googleapiclient is not installed. + ValueError: If no TPUs are specified. """ + if isinstance(tpu, list): + if not tpu: + raise ValueError('At least one TPU must be specified.') + if len(tpu) != 1: + raise NotImplementedError( + 'Using multiple TPUs in a single session is not yet implemented') + tpu = tpu[0] + + in_gke = self._inGke() + # When using GKE with Cloud TPUs, the env variable will be set. + if tpu is None and in_gke: + tpu = self._gkeMaster() + + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes + self._job_name = job_name + self._credentials = credentials - if not project: - project = self._requestComputeMetadata('/project/project-id') + should_resolve = self._shouldResolve() - if not zone: - zone_path = self._requestComputeMetadata('/instance/zone') + if not project and should_resolve: + project = compat.as_str( + self._requestComputeMetadata('project/project-id')) + + if not zone and should_resolve: + zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) zone = zone_path.split('/')[-1] self._project = project self._zone = zone - self._tpu_names = tpu_names - self._job_name = job_name - self._credentials = credentials - if credentials == 'default': + if credentials == 'default' and should_resolve: if _GOOGLE_API_CLIENT_INSTALLED: self._credentials = GoogleCredentials.get_application_default() - if service is None: + if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver') + 'TPU cluster resolver. Execute: `pip install ' + '--upgrade google-api-python-client` to install with ' + 'pip.') self._service = discovery.build( 'tpu', 'v1alpha1', @@ -107,25 +160,42 @@ class TPUClusterResolver(ClusterResolver): else: self._service = service - def get_master(self): - """Get the ClusterSpec grpc master path. + self._coordinator_name = coordinator_name + if coordinator_name and not coordinator_address and (should_resolve or + in_gke): + self._start_local_server() + else: + self._coordinator_address = coordinator_address - This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the - ClusterSpec returned by the cluster_spec function. This is suitable for use - for the `master` argument in tf.Session() when you are using one TPU. + def master(self): + """Get the Master string to be used for the session. + + In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of + first instance in the ClusterSpec returned by the cluster_spec function. + + If a non-TPU name is used when constructing a TPUClusterResolver, that will + be returned instead (e.g. If the tpus argument's value when constructing + this TPUClusterResolver was 'grpc://10.240.1.2:8470', + 'grpc://10.240.1.2:8470' will be returned). Returns: - string, the grpc path of the first instance in the ClusterSpec. + string, the connection string to use when creating a session. Raises: ValueError: If none of the TPUs specified exists. """ + if not self._shouldResolve(): + return self._tpu + job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: raise ValueError('No TPUs exists with the specified names exist.') return 'grpc://' + job_tasks[0] + def get_master(self): + return self.master() + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. @@ -134,17 +204,73 @@ class TPUClusterResolver(ClusterResolver): Returns: A ClusterSpec containing host information returned from Cloud TPUs. + + Raises: + RuntimeError: If the provided TPU is not healthy. """ - worker_list = [] + ############################################################################ + # There are 5 potential cases this code must handle: + # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and + # a. Create a ClusterSpec that includes the coordinator job + # b. Create a ClusterSpec without the coordinator job. + # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of + # tasks and + # a. Create a ClusterSpec with the coordinator + # b. Create a ClusterSpec without the coordinator + # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec. + ############################################################################ - for tpu_name in self._tpu_names: + if self._shouldResolve(): + # Case 1. full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, tpu_name) + self._project, self._zone, compat.as_text(self._tpu)) request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() - if 'health' in response and response['health'] == 'HEALTHY': + if 'health' in response and response['health'] != 'HEALTHY': + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, + response['health'])) + + if 'networkEndpoints' in response: + worker_list = [ + '%s:%s' % (endpoint['ipAddress'], endpoint['port']) + for endpoint in response['networkEndpoints'] + ] + else: + # Fall back to the deprecated response format instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) + worker_list = [instance_url] + + cluster_spec = {self._job_name: worker_list} + else: + if not self._tpu.startswith(compat.as_bytes('grpc://')): + # Case 3. + return server_lib.ClusterSpec({}) + # Case 2. + cluster_spec = {self._job_name: [self._tpu[len( + compat.as_bytes('grpc://')):]]} + + if self._coordinator_address: + # {1, 2}.a + cluster_spec[self._coordinator_name] = [self._coordinator_address] + + return server_lib.ClusterSpec(cluster_spec) + + def _start_local_server(self): + address = self._requestComputeMetadata('instance/network-interfaces/0/ip') + self._server = server_lib.Server( + { + 'local': ['0.0.0.0:0'] + }, protocol='grpc', config=None, start=True) + # self._server.target is of the form: grpc://ipaddress:port + target = compat.as_bytes(self._server.target) + splits = target.split(compat.as_bytes(':')) + assert len(splits) == 3, self._server.target + assert splits[0] == compat.as_bytes('grpc'), self._server.target + self._coordinator_port = compat.as_text(splits[2]) + self._coordinator_address = '%s:%s' % ( + address, compat.as_text(self._coordinator_port)) - return ClusterSpec({self._job_name: worker_list}) + def __deepcopy__(self, memo): + # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. + return self diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 4fd34629cf74f90869c77b8cb098d3c585a49404..dff7a03b6847fb6e159dc2fa9832fceb3dfe2d54 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -18,10 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib - +from tensorflow.python.util import compat mock = test.mock @@ -50,10 +52,12 @@ class MockNodeClass(object): def mock_request_compute_metadata(cls, *args, **kwargs): del cls, kwargs # Unused. - if args[0] == '/project/project-id': + if args[0] == 'project/project-id': return 'test-project' - elif args[0] == '/instance/zone': + elif args[0] == 'instance/zone': return 'projects/test-project/locations/us-central1-c' + elif args[0] == 'instance/network-interfaces/0/ip': + return '10.128.1.2' return '' @@ -71,18 +75,17 @@ class TPUClusterResolverTest(test.TestCase): expected_proto: Expected protobuf """ self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def()) - self.assertProtoEquals( - expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def()) - self.assertProtoEquals( - expected_proto, - server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def()) self.assertProtoEquals( expected_proto, - server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def()) + server_lib.ClusterSpec(cluster_spec).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_cluster_def()).as_cluster_def()) + self.assertProtoEquals(expected_proto, + server_lib.ClusterSpec( + cluster_spec.as_dict()).as_cluster_def()) - def mock_service_client( - self, - tpu_map=None): + def mock_service_client(self, tpu_map=None): if tpu_map is None: tpu_map = {} @@ -98,8 +101,7 @@ class TPUClusterResolverTest(test.TestCase): return mock_client - @mock.patch.object(TPUClusterResolver, - '_requestComputeMetadata', + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', mock_request_compute_metadata) def testRetrieveProjectAndZoneFromMetadata(self): tpu_map = { @@ -113,17 +115,27 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver( project=None, zone=None, - tpu_names=['test-tpu-1'], + tpu=['test-tpu-1'], credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) + service=self.mock_service_client(tpu_map=tpu_map), + coordinator_name='coordinator') actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + job { + name: 'coordinator' + tasks { key: 0 value: '10.128.1.2:%s' } + } + job { + name: 'worker' + tasks { key: 0 value: '10.1.2.3:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) - def testSimpleSuccessfulRetrieval(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', @@ -133,116 +145,229 @@ class TPUClusterResolverTest(test.TestCase): } tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu_names=['test-tpu-1'], + project=None, + zone=None, + tpu=['test-tpu-1'], + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testMultipleSuccessfulRetrieval(self): + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu=['test-tpu-1'], + coordinator_name='coordinator', + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' } - tasks { key: 1 value: '10.1.2.3:8470' } } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testHealthyTpuNodeRetrieval(self): + def testNewNetworkEndpointFormat(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { - 'ipAddress': '10.7.8.9', - 'port': '8470', - 'health': 'UNHEALTHY' + 'health': 'HEALTHY', + 'networkEndpoints': [{ + 'ipAddress': '10.2.3.4', + 'port': 8470, + }] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + tpu='test-tpu-1', + coordinator_name='coordinator', + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { - name: 'tpu_worker' - tasks { - key: 0 - value: '10.1.2.3:8470' - } - } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual('grpc://10.2.3.4:8470', tpu_cluster_resolver.master()) - def testGetMasterMultipleEntries(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testPodResolution(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] + } + } + + tpu_cluster_resolver = TPUClusterResolver( + tpu='test-tpu-1', + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map), + coordinator_name='coordinator') + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'coordinator', + tasks { key: 0 value: '10.128.1.2:%s'} + } + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + + def testPodResolutionNoCoordinator(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu='test-tpu-1', + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) - self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) def testGetMasterNoEntries(self): tpu_map = {} + with self.assertRaises(ValueError): + TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu=[], + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + # TODO(saeta): Convert to parameterized test when included in OSS TF. + def verifyShouldResolve(self, tpu, should_resolve): tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=[], + tpu=tpu, + coordinator_name=None, credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - with self.assertRaises(ValueError): - tpu_cluster_resolver.get_master() + service=self.mock_service_client(tpu_map={})) + self.assertEqual(should_resolve, tpu_cluster_resolver._shouldResolve(), + "TPU: '%s'" % tpu) + + def testShouldResolveNoName(self): + self.verifyShouldResolve('', False) + + def testShouldResolveLocal(self): + self.verifyShouldResolve('local', False) + + def testShouldResolveGrpc(self): + self.verifyShouldResolve('grpc://10.1.2.3:8470', False) + + def testShouldResolveBns(self): + self.verifyShouldResolve('/bns/foo/bar', False) + + def testShouldResolveName(self): + self.verifyShouldResolve('mytpu', True) + + def testShouldResolveList(self): + self.verifyShouldResolve(['myothertpu'], True) + + def testShouldResolveGrpcPrefix(self): + self.verifyShouldResolve('grpctpu', True) + + def testNoCallComputeMetadata(self): + tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') + self.assertEqual( + compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) + self.assertEqual( + server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec()) + + def testGkeEnvironment(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' + self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeMaster())) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 16317f538f3890661f1b59ea39fe67dcf04d0d0a..10f29deca08f2a70eeeb7c758f3268e0bce11e8c 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -31,10 +31,14 @@ option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF) option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF) option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) -option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." OFF) +# SIMD, MKL and MKLDNN options +option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions" OFF) +option(tensorflow_ENABLE_MKL_SUPPORT "Enable Intel MKL support" OFF) +option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires MKL enabled" OFF) + # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") @@ -124,8 +128,16 @@ endif() add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) - add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) - add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + # 64 bits + add_definitions(-DWIN64) + elseif(CMAKE_SIZEOF_VOID_P EQUAL 4) + # 32 bits + # temporary fix for #18241 + add_definitions(-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=std::int64_t) + endif() + add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11) + add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH) add_definitions(-DTF_COMPILE_LIBRARY) add_definitions(/bigobj /nologo /EHsc /GF /MP /Gm-) @@ -162,12 +174,21 @@ endif() # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) + include(CheckCXXCompilerFlag) + if (tensorflow_ENABLE_MKL_SUPPORT) + add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) + if (NOT tensorflow_ENABLE_MKLDNN_SUPPORT) + add_definitions(-DINTEL_MKL_ML) + endif() + endif() + CHECK_CXX_COMPILER_FLAG("-fopenmp" COMPILER_OPT_OPENMP_SUPPORT) + if (COMPILER_OPT_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") + endif() if (WIN32) - CHECK_CXX_COMPILER_FLAG("${tensorflow_WIN_CPU_SIMD_OPTIONS}" COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) + CHECK_CXX_COMPILER_FLAG(${tensorflow_WIN_CPU_SIMD_OPTIONS} COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) if(COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${tensorflow_WIN_CPU_SIMD_OPTIONS}") - else() - message(FATAL_ERROR "${tensorflow_WIN_CPU_SIMD_OPTIONS} not supported") endif() endif() endif() @@ -298,6 +319,43 @@ if(HAIKU) list(APPEND tensorflow_EXTERNAL_LIBRARIES network) endif() +if (tensorflow_ENABLE_MKL_SUPPORT) + if (WIN32) + find_path(MKL_HOME_PLATFORM mkl + PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + PATH_SUFFIXES windows) + set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) + set(MKL_LINK_DIRS + ${MKL_HOME_PLATFORM}/mkl/lib/intel64 + ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt + ${MKL_HOME_PLATFORM}/compiler/lib/intel64 + ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) + set(MKL_REDIST_DLL_DIRS + ${MKL_HOME_PLATFORM}/redist/intel64/mkl + ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt + ${MKL_HOME_PLATFORM}/redist/intel64/compiler) + list(APPEND tensorflow_EXTERNAL_LIBRARIES + mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) + endif() + if (UNIX) + # Fix me: complete the path on linux + find_path(MKL_HOME_PLATFORM mkl + HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + PATH_SUFFIXES linux) + set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) + set(MKL_LINK_DIRS) # incompleted + set(MKL_REDIST_SO_DIRS) # incompleted + endif() + include_directories(${MKL_INCLUDE_DIRS}) + link_directories(${MKL_LINK_DIRS}) + if (tensorflow_ENABLE_MKLDNN_SUPPORT) + include(mkldnn) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + include_directories(${mkldnn_INCLUDE_DIRS}) + endif() +endif (tensorflow_ENABLE_MKL_SUPPORT) + if (tensorflow_ENABLE_GPU) if (NOT WIN32) # Default install paths for cuda libraries in Linux @@ -341,7 +399,8 @@ if (tensorflow_ENABLE_GPU) if(NOT CUDNN_HOME) set(CUDNN_HOME ${CUDA_TOOLKIT_TARGET_DIR}) endif(NOT CUDNN_HOME) - include_directories(${CUDNN_HOME}) + set(CUDNN_INCLUDE "${CUDNN_HOME}/include") + set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib) else (WIN32) @@ -369,10 +428,10 @@ if (tensorflow_ENABLE_GPU) message("culibos-static: ${culibos_STATIC_LIBRARY}") endif (NOT culibos_STATIC_LIBRARY) - include_directories(${CUDNN_INCLUDE}) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) @@ -388,31 +447,22 @@ if (tensorflow_ENABLE_GPU) "#endif // CUDA_CUDA_CONFIG_H_\n" ) - if (WIN32) - # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk - # installs them under cuda/version/include and to avoid that we need to change tf we copy a - # few files to cuda/include - FILE(COPY - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_HOME}/include/cudnn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h - DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include - ) - else(WIN32) - # Linux has slightly differnt install paths than Windows - FILE(COPY - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_INCLUDE}/cudnn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include - ) - endif(WIN32) + # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk + # installs them under cuda/version/include and to avoid that we need to change tf we copy a + # few files to cuda/include + FILE(COPY + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h + ${CUDNN_INCLUDE}/cudnn.h + DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include + ) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 8f85a75ee466dbac524a1266dc2522109ca77cd5..0b79f718d4823a987e02804f59a432ee46d0ada3 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -26,7 +26,7 @@ The CMake files in this directory can build the core TensorFlow runtime, an example C++ binary, and a PIP package containing the runtime and Python bindings. -### Pre-requisites +### Prerequisites * CMake version 3.5 or later. @@ -34,14 +34,16 @@ bindings. * [SWIG](http://www.swig.org/download.html) -* Additional pre-requisites for Microsoft Windows: +* Additional prerequisites for Microsoft Windows: - Visual Studio 2015 - Python 3.5 - - NumPy 1.11.0 or later -* Additional pre-requisites for Linux: +* Additional prerequisites for Linux: - Python 2.7 or later - [Docker](https://www.docker.com/) (for automated testing) + +* Python dependencies: + - wheel - NumPy 1.11.0 or later ### Known-good configurations @@ -102,7 +104,7 @@ ops or APIs. Step-by-step Windows build ========================== -1. Install the pre-requisites detailed above, and set up your environment. +1. Install the prerequisites detailed above, and set up your environment. * The following commands assume that you are using the Windows Command Prompt (`cmd.exe`). You will need to set up your environment to use the @@ -126,6 +128,18 @@ Step-by-step Windows build D:\local\cuda\bin ``` + * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable. + + In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable. + It should contain the directory of the MKL dlls. For example: + + ``` + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt + ``` + + * We assume that `cmake` and `git` are installed and in your `%PATH%`. If for example `cmake` is not in your path and it is installed in `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory @@ -164,7 +178,15 @@ Step-by-step Windows build More? -Dtensorflow_ENABLE_GPU=ON ^ More? -DCUDNN_HOME="D:\...\cudnn" ``` + To build with MKL support add "^" at the end of the last line above following with: + + ``` + More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ + More? -DMKL_HOME="D:\...\compilers_and_libraries" + ``` + To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows: + ``` More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX ``` @@ -224,6 +246,7 @@ Step-by-step Windows build ``` ctest -C RelWithDebInfo ``` + * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. After building the python wheel, you need to install the new wheel before running the tests. @@ -232,6 +255,12 @@ Step-by-step Windows build ctest -C RelWithDebInfo ``` + * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). + CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl. + + * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. + + 4. Invoke MSBuild to build TensorFlow. To build the C++ example program, which will be created as a `.exe` @@ -249,6 +278,7 @@ Step-by-step Windows build D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj ``` + Linux Continuous Integration build ================================== diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index 836889895567f679d9960e29ece1600d1a7a58eb..98a8c7e736e5c8c407b90e8eac440cdc7ab21579 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) -set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip) +set(cub_HASH SHA256=6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index a235442dc5c0a07e249653381436eeae81575883..cdaa6b73b93666d272faacb869e8272561a2c74c 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) -set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip) +set(gemmlowp_HASH SHA256=b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index a9f43a3ecba4830533efcc13f8c4c1c61fe1ef78..693dc7cd673233b889b35a3f3170b57581da9a9f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) +set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") @@ -35,6 +35,8 @@ else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a639fdee367f060d4c8a79267803da6ffe3dc503 --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(mkldnn_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/include) +set(mkldnn_URL https://github.com/01org/mkl-dnn.git) +set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src) +set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) + +if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + else() + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + endif() +else() + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) +endif() + +ExternalProject_Add(mkldnn + PREFIX mkldnn + GIT_REPOSITORY ${mkldnn_URL} + GIT_TAG ${mkldnn_TAG} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${mkldnn_STATIC_LIBRARIES} + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DMKLINC:STRING=${MKL_INCLUDE_DIRS} +) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index f3a37ff5088e3f9e54e38c0edb5777c27b26969f..b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) +set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 6cd66a65990e7a2b963b52b310061b551752cd4d..ad2af01bc002555ce48f8b9bfb7d8d724a1a7dc8 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -15,32 +15,33 @@ include (ExternalProject) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) -set(png_URL https://storage.googleapis.com/libpng-public-archive/libpng-1.2.53.tar.gz) -set(png_HASH SHA256=e05c9056d7f323088fd7824d8c6acc03a4a758c4b4916715924edc5dd3223a72) +set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) +set(png_HASH SHA256=e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef) set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(png_STATIC_LIBRARIES - debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib - optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_staticd.lib + optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_static.lib) else() if(CMAKE_BUILD_TYPE EQUAL Debug) set(png_STATIC_LIBRARIES - ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib) + ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_staticd.lib) else() set(png_STATIC_LIBRARIES - ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_static.lib) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) endif() set(png_HEADERS - "${png_INSTALL}/include/libpng12/png.h" - "${png_INSTALL}/include/libpng12/pngconf.h" + "${png_INSTALL}/include/libpng16/png.h" + "${png_INSTALL}/include/libpng16/pngconf.h" + "${png_INSTALL}/include/libpng16/pnglibconf.h" ) ExternalProject_Add(png diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index aba8a5244e17d717293deec6d9b6e8e725ef010e..ab464bc99a43138130bb2758ae28ecef29805c31 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a) +set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 57c4ae76517e4d7247093edd5e5bd95a83258d87..7f835d2d519273a6d52d12f92ed585a4ddbeb973 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) -set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) -set(sqlite_HASH SHA256=208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4) +set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip) +set(sqlite_HASH SHA256=4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc) set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index aaae18a313dd082b428654091c9411600c981ec9..6f059c7225dd0938b758e8f9c28ec36fcff6db4c 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -42,7 +42,6 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/c++11") add_definitions ("-DNSYNC_USE_CPP11_TIMEPOINT -DNSYNC_ATOMIC_CPP11") set (NSYNC_OS_CPP_SRC - "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/c++11/src/per_thread_waiter.cc" "platform/c++11/src/yield.cc" "platform/c++11/src/time_rep_timespec.cc" @@ -52,6 +51,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/win32") add_compile_options ("/TP") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/win32/src/clock_gettime.c" "platform/win32/src/pthread_key_win32.cc" ${NSYNC_OS_CPP_SRC} @@ -68,6 +68,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/c++11/src/nsync_semaphore_mutex.cc" "platform/posix/src/clock_gettime.c" "platform/posix/src/nsync_semaphore_mutex.c" ) @@ -75,9 +76,11 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") "platform/posix/src/start_thread.c" ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") + include_directories (BEFORE "${PROJECT_SOURCE_DIR}/platform/c++11.futex") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/linux/src/nsync_semaphore_futex.c" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -87,6 +90,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -96,6 +100,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC @@ -105,6 +110,7 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC + "platform/c++11/src/nsync_semaphore_mutex.cc" ${NSYNC_OS_CPP_SRC} ) set (NSYNC_TEST_OS_SRC diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index bfe53c01b3b5fb9db8a5d8fa280d1d7f98974882..91839194c7c214fe910ff78723ab418f86c7fac0 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -79,9 +79,11 @@ tensorflow/python/keras/_impl/keras/preprocessing tensorflow/python/keras/_impl/keras/utils tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions tensorflow/python/kernel_tests/linalg tensorflow/python/kernel_tests/random +tensorflow/python/kernel_tests/testdata tensorflow/python/layers tensorflow/python/lib tensorflow/python/lib/core @@ -102,6 +104,8 @@ tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf tensorflow/tools +tensorflow/tools/api +tensorflow/tools/api/generator tensorflow/tools/graph_transforms tensorflow/contrib tensorflow/contrib/all_reduce @@ -147,8 +151,6 @@ tensorflow/contrib/crf tensorflow/contrib/crf/python tensorflow/contrib/crf/python/ops tensorflow/contrib/cudnn_rnn -tensorflow/contrib/cudnn_rnn/kernels -tensorflow/contrib/cudnn_rnn/ops tensorflow/contrib/cudnn_rnn/python tensorflow/contrib/cudnn_rnn/python/layers tensorflow/contrib/cudnn_rnn/python/ops @@ -160,6 +162,9 @@ tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto tensorflow/contrib/deprecated +tensorflow/contrib/distribute +tensorflow/contrib/distribute/python +tensorflow/contrib/distribute/python/examples tensorflow/contrib/distributions tensorflow/contrib/distributions/python tensorflow/contrib/distributions/python/ops @@ -331,6 +336,7 @@ tensorflow/contrib/nccl/kernels tensorflow/contrib/nccl/ops tensorflow/contrib/nccl/python tensorflow/contrib/nccl/python/ops +tensorflow/contrib/nearest_neighbor tensorflow/contrib/nearest_neighbor/kernels tensorflow/contrib/nearest_neighbor/ops tensorflow/contrib/nearest_neighbor/python @@ -341,6 +347,7 @@ tensorflow/contrib/nn/python/ops tensorflow/contrib/opt tensorflow/contrib/opt/python tensorflow/contrib/opt/python/training +tensorflow/contrib/optimizer_v2 tensorflow/contrib/pi_examples tensorflow/contrib/pi_examples/camera tensorflow/contrib/pi_examples/label_image @@ -349,6 +356,9 @@ tensorflow/contrib/periodic_resample tensorflow/contrib/periodic_resample/python tensorflow/contrib/periodic_resample/python/ops tensorflow/contrib/predictor +tensorflow/contrib/proto +tensorflow/contrib/proto/python +tensorflow/contrib/proto/python/ops tensorflow/contrib/quantization tensorflow/contrib/quantization/python tensorflow/contrib/quantize @@ -357,6 +367,10 @@ tensorflow/contrib/receptive_field tensorflow/contrib/receptive_field/python tensorflow/contrib/receptive_field/python/util tensorflow/contrib/receptive_field/python/util/examples +tensorflow/contrib/recurrent +tensorflow/contrib/recurrent/python +tensorflow/contrib/recurrent/python/ops +tensorflow/contrib/recurrent/python/kernel_tests tensorflow/contrib/reduce_slice_ops tensorflow/contrib/reduce_slice_ops/kernels tensorflow/contrib/reduce_slice_ops/ops @@ -377,6 +391,9 @@ tensorflow/contrib/rnn/ops tensorflow/contrib/rnn/python tensorflow/contrib/rnn/python/kernel_tests tensorflow/contrib/rnn/python/ops +tensorflow/contrib/rpc +tensorflow/contrib/rpc/python +tensorflow/contrib/rpc/python/ops tensorflow/contrib/saved_model tensorflow/contrib/saved_model/python tensorflow/contrib/saved_model/python/saved_model diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index 8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57..d63c41db844af243f0c6600b1565635ac9b91cac 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -1,4 +1,5 @@ tensorflow/core +tensorflow/core/kernels/boosted_trees tensorflow/core/profiler tensorflow/python tensorflow/contrib/boosted_trees/proto diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c index 968ab13a0c43793341431248713f81010c87f148..9e355da33a7258119b6086216f5487d7ea94716c 100644 --- a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c @@ -1,3 +1,18 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // This is a program to test if compiler is compatible with CUDA. #define __CUDACC__ #include "crt/host_config.h" diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc index 968ab13a0c43793341431248713f81010c87f148..beb574061bea8d04af8386223749677ae36a5d9b 100644 --- a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc @@ -1,3 +1,18 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================*/ + // This is a program to test if compiler is compatible with CUDA. #define __CUDACC__ #include "crt/host_config.h" diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 96ac60d095dbc84470ff1be92f4bf52bb420fc52..a54cbff33b66d63d7229fa2f50b8a4ca962111ed 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,6 +63,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc" ) +file(GLOB_RECURSE tf_core_cpu_whitelisted_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc" +) +list(REMOVE_ITEM tf_core_cpu_exclude_srcs ${tf_core_cpu_whitelisted_srcs}) list(REMOVE_ITEM tf_core_cpu_srcs ${tf_core_cpu_exclude_srcs}) if (tensorflow_ENABLE_GPU) @@ -79,6 +85,7 @@ if (tensorflow_ENABLE_GPU) "${tensorflow_source_dir}/tensorflow/core/*test*.cc" ) list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_gpu_exclude_srcs}) + list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_cpu_whitelisted_srcs}) list(APPEND tf_core_cpu_srcs ${tf_core_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index a1c320347fe60f87806736befc677541a93e7e93..b47c32f1c48b3d42fe5b4ba115cc2a511b7ee5f4 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -276,7 +276,7 @@ add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) add_custom_command(OUTPUT ${VERSION_INFO_CC} COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py - --raw_generate ${VERSION_INFO_CC} + ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} DEPENDS __force_rebuild) set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) @@ -341,9 +341,3 @@ add_dependencies(tf_core_framework tf_core_lib proto_text ) - -if(WIN32) - # Cmake > 3.6 will quote this as -D"__VERSION__=\"MSVC\"" which nvcc fails on. - # Instead of defining this global, limit it to tf_core_framework where its used. - target_compile_definitions(tf_core_framework PRIVATE __VERSION__="MSVC") -endif() diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f219d5eb577afa9edaadca09aef9869c81d2bd87..ed018b4fed8e47632f632723f19cc755f2079f86 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -67,10 +67,10 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 59e094812aaf4da2549d96314fc550e5635f9de8..e558691de4b74988031f7b2204aad92e8c7af68b 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -15,19 +15,23 @@ set(tf_op_lib_names "audio_ops" "array_ops" - "batch_ops" + "batch_ops" "bitwise_ops" + "boosted_trees_ops" "candidate_sampling_ops" "checkpoint_ops" "control_flow_ops" "ctc_ops" + "cudnn_rnn_ops" "data_flow_ops" "dataset_ops" + "decode_proto_ops" + "encode_proto_ops" "functional_ops" "image_ops" "io_ops" "linalg_ops" - "list_ops" + "list_ops" "lookup_ops" "logging_ops" "manip_ops" @@ -38,6 +42,7 @@ set(tf_op_lib_names "random_ops" "remote_fused_graph_ops" "resource_variable_ops" + "rpc_ops" "script_ops" "sdca_ops" "set_ops" @@ -47,7 +52,7 @@ set(tf_op_lib_names "state_ops" "stateless_random_ops" "string_ops" - "summary_ops" + "summary_ops" "training_ops" ) @@ -84,7 +89,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index b730ebd3baacafe8ae401e8987104f3062372954..c4bdb69d828b269e6246777e74c3756ba1c4b96f 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -319,6 +319,7 @@ GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") +GENERATE_PYTHON_OP_LIB("boosted_trees_ops") GENERATE_PYTHON_OP_LIB("math_ops") GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") @@ -326,8 +327,13 @@ GENERATE_PYTHON_OP_LIB("checkpoint_ops") GENERATE_PYTHON_OP_LIB("control_flow_ops" ADDITIONAL_LIBRARIES $) GENERATE_PYTHON_OP_LIB("ctc_ops") +GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") +GENERATE_PYTHON_OP_LIB("decode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) +GENERATE_PYTHON_OP_LIB("encode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") @@ -341,6 +347,8 @@ GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py) GENERATE_PYTHON_OP_LIB("resource_variable_ops") +GENERATE_PYTHON_OP_LIB("rpc_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") GENERATE_PYTHON_OP_LIB("set_ops") @@ -348,6 +356,7 @@ GENERATE_PYTHON_OP_LIB("state_ops") GENERATE_PYTHON_OP_LIB("sparse_ops") GENERATE_PYTHON_OP_LIB("spectral_ops") GENERATE_PYTHON_OP_LIB("string_ops") +GENERATE_PYTHON_OP_LIB("summary_ops") GENERATE_PYTHON_OP_LIB("user_ops") GENERATE_PYTHON_OP_LIB("training_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/training/gen_training_ops.py) @@ -366,8 +375,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_coder_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" @@ -419,8 +426,6 @@ GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) -GENERATE_PYTHON_OP_LIB("summary_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) @@ -475,6 +480,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor_bridge.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_exception_registry.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" @@ -547,12 +554,13 @@ if(WIN32) set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow.def") endif() set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE) - + math(EXPR tensorflow_target_bitness "${CMAKE_SIZEOF_VOID_P}*8") add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py --input "${pywrap_tensorflow_internal_static_dependencies}" --output "${pywrap_tensorflow_deffile}" --target _pywrap_tensorflow_internal.pyd + --bitness "${tensorflow_target_bitness}" BYPRODUCTS ${pywrap_tensorflow_deffile} # Required for Ninja ) endif(WIN32) @@ -582,6 +590,12 @@ add_library(pywrap_tensorflow_internal SHARED ${pywrap_tensorflow_deffile} ) +# There is a bug in GCC 5 resulting in undefined reference to a __cpu_model function when +# linking to the tensorflow library. Adding the following libraries fixes it. +if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0) + target_link_libraries(pywrap_tensorflow_internal PRIVATE gcc_s gcc) +endif() + if(WIN32) add_dependencies(pywrap_tensorflow_internal pywrap_tensorflow_internal_static) endif(WIN32) @@ -685,6 +699,77 @@ AddUserOps(TARGET _beam_search_ops DEPENDS pywrap_tensorflow_internal tf_python_ops DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/) +if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + else() + add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + endif() +else() + add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) +endif() + + +######################################################## +# Generate API __init__.py files. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}") + endif() +endforeach(api_init_file) +set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") +file(WRITE "${api_init_list_file}" "${api_init_files}") + +# Run create_python_api.py to generate __init__.py files. +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}" + + # Re-add tensorflow/__init__.py back. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" +) + +add_custom_target(tf_python_api SOURCES ${api_init_files}) +add_dependencies(tf_python_api tf_python_ops) + + ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -694,6 +779,7 @@ add_dependencies(tf_python_build_pip_package tf_python_copy_scripts_to_destination tf_python_touchup_modules tf_python_ops + tf_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. @@ -706,25 +792,6 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) -if(WIN32) - if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) - else() - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) - endif() -else() - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so) -endif() add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 6d36d5fc5c2854b2d7d2542a3cb12e033e193b88..38f40452b533fdc0dba6ac686a0ff43a2ef13cb8 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -52,12 +52,13 @@ if(WIN32) set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/tensorflow.def") endif() set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) - + math(EXPR tensorflow_target_bitness "${CMAKE_SIZEOF_VOID_P}*8") add_custom_command(TARGET tensorflow_static POST_BUILD COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py --input "${tensorflow_static_dependencies}" --output "${tensorflow_deffile}" --target tensorflow.dll + --bitness "${tensorflow_target_bitness}" ) endif(WIN32) @@ -100,8 +101,7 @@ if(WIN32) endif(WIN32) target_include_directories(tensorflow PUBLIC - $ - $) + $) install(TARGETS tensorflow EXPORT tensorflow_export RUNTIME DESTINATION bin @@ -133,10 +133,6 @@ install(DIRECTORY ${tensorflow_source_dir}/tensorflow/stream_executor/ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src/google/ DESTINATION include/google FILES_MATCHING PATTERN "*.h") -# nsync headers -install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/ - DESTINATION include/external/nsync - FILES_MATCHING PATTERN "*.h") # Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/ DESTINATION include/Eigen) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 91ca33f4c4d5f6c822f45b0676e6e46d2e4c2860..af48ef1fd40456162fee8b1e2c3ca45ecdb58830 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -65,6 +65,12 @@ if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" ) + if (NOT tensorflow_BUILD_CC_TESTS) + file(GLOB tf_stream_executor_gpu_tests + "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*_test.cc" + ) + list(REMOVE_ITEM tf_stream_executor_gpu_srcs ${tf_stream_executor_gpu_tests}) + endif() list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 1c4ebd7f0c1113bcd0857fb0858df2248499f920..92f2ab6dea8e7da5dd8481639eda24e31c06848f 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -195,9 +195,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py" # requires scipy "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py" # Takes very long to run without sharding (defined in bazel build file). "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" # Loading resources in contrib doesn't seem to work on Windows @@ -208,6 +210,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py" # Test is flaky on Windows GPU builds (b/38283730). "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" + # Disable following manual tag in BUILD. + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + ) if (WIN32) set(tf_test_src_py_exclude @@ -222,6 +227,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/grpc_large_data_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different @@ -278,6 +284,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py" # Segfaults on Windows. # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. @@ -475,6 +482,10 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*_test.cc" ) + list(REMOVE_ITEM tf_test_src_simple + ${tf_core_profiler_test_srcs} + ) + set(tf_test_lib tf_test_lib) add_library(${tf_test_lib} STATIC ${tf_src_testlib}) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index 53c2285699a6ca94e1e6b147080338b507f4d768..cffe069aa352f8a6f2c436bc70b62f54e2336ac6 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -63,7 +63,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" r"^(TFE_\w*)$|" r"tensorflow::|" r"functor::|" - r"nsync_|" + r"\?nsync_|" r"perftools::gputools") # We want to identify data members explicitly in the DEF file, so that no one @@ -87,6 +87,7 @@ def get_args(): required=True) parser.add_argument("--output", help="output deffile", required=True) parser.add_argument("--target", help="name of the target", required=True) + parser.add_argument("--bitness", help="build target bitness", required=True) args = parser.parse_args() return args @@ -125,7 +126,10 @@ def main(): # Header for the def file. def_fp.write("LIBRARY " + args.target + "\n") def_fp.write("EXPORTS\n") - def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n") + if args.bitness == "64": + def_fp.write("\t??1OpDef@tensorflow@@UEAA@XZ\n") + else: + def_fp.write("\t??1OpDef@tensorflow@@UAE@XZ\n") # Each symbols returned by undname matches the same position in candidates. # We compare on undname but use the decorated name from candidates. diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index ec3d550b70d2aaa23b989c44f3d86fa87cffb335..9ca4ce8a9c765677865f77ea4982ad8613ce334c 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -92,6 +92,34 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "pmf_to_cdf_op", + srcs = ["kernels/pmf_to_cdf_op.cc"], + visibility = ["//visibility:public"], + deps = [ + ":coder_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "pmf_to_cdf_op_test", + size = "small", + srcs = ["kernels/pmf_to_cdf_op_test.cc"], + deps = [ + ":pmf_to_cdf_op", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + cc_library( name = "all_ops", deps = [":coder_ops_op_lib"], @@ -99,12 +127,16 @@ cc_library( cc_library( name = "all_kernels", - deps = [":range_coder_ops"], + deps = [ + ":pmf_to_cdf_op", + ":range_coder_ops", + ], ) tf_custom_op_library( name = "python/ops/_coder_ops.so", srcs = [ + "kernels/pmf_to_cdf_op.cc", "kernels/range_coder.cc", "kernels/range_coder.h", "kernels/range_coder_ops.cc", @@ -154,14 +186,3 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c787e8edede0942cd152eafa6333849d194e58b6 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc @@ -0,0 +1,150 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { +using errors::InvalidArgument; + +class PmfToCdfOp : public OpKernel { + public: + explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); + OP_REQUIRES( + context, 0 < precision_ && precision_ <= 16, + InvalidArgument("`precision` must be in [1, 16]: ", precision_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& pmf_tensor = context->input(0); + + TensorShape shape = pmf_tensor.shape(); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape), + InvalidArgument("`pmf` should be at least 1-D.")); + OP_REQUIRES( + context, shape.dim_size(shape.dims() - 1) > 1, + InvalidArgument("`pmf` size should be at least 2 in the last axis.")); + shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1); + + Tensor* cdf_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor)); + + auto pmf = pmf_tensor.flat_inner_dims(); + auto cdf = cdf_tensor->flat_inner_dims(); + CHECK_EQ(pmf.dimension(0), cdf.dimension(0)); + CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1)); + + const double n = pmf.dimension(1); + const int64 cost_per_unit = static_cast(50.0 * n * std::log2(n)); + thread::ThreadPool* thread_pool = + context->device()->tensorflow_cpu_worker_threads()->workers; + thread_pool->ParallelFor( + pmf.dimension(0), cost_per_unit, + [this, pmf, &cdf](int64 start, int64 limit) { + const gtl::ArraySlice::size_type pmf_size = pmf.dimension(1); + for (int64 i = start; i < limit; ++i) { + cdf(i, 0) = 0; + PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size}); + } + }); + } + + private: + struct Item { + Item(int32* p, double mass) : pointer(p), mass(mass) { + penalty = ComputeNextPenalty(); + } + + void Decrease() { + CHECK_GT(*pointer, 1); + --*pointer; + penalty = ComputeNextPenalty(); + } + + friend bool operator<(const Item& lhs, const Item& rhs) { + return lhs.penalty < rhs.penalty; + } + + double ComputeNextPenalty() { + if (*pointer <= 1) { + return std::numeric_limits::infinity(); + } + return mass * (std::log2(*pointer) - std::log2(*pointer - 1)); + } + + int32* pointer; + double mass; + double penalty; + }; + + void PerShard(gtl::ArraySlice pmf, + gtl::MutableArraySlice cdf) const { + CHECK_EQ(pmf.size(), cdf.size()); + + const int32 normalizer = 1 << precision_; + std::transform(pmf.begin(), pmf.end(), cdf.begin(), + [normalizer](float mass) { + int32 value = std::rint(mass * normalizer); + // NOTE: Consider checking if mass > 0. + value = std::max(value, 1); + return value; + }); + + int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0); + if (sum > normalizer) { + std::vector queue; + queue.reserve(cdf.size()); + for (int i = 0; i < cdf.size(); ++i) { + queue.emplace_back(&cdf[i], pmf[i]); + } + + std::sort(queue.begin(), queue.end()); + while (sum-- > normalizer) { + queue[0].Decrease(); + // Performs a linear search because this find_if is likely to return + // iterator very close to the begin. + auto iter = + std::find_if(std::next(queue.begin()), queue.end(), + [&queue](const Item& rhs) { return queue[0] < rhs; }); + std::rotate(queue.begin(), std::next(queue.begin()), iter); + } + } + std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); + } + + int precision_; +}; + +REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU), + PmfToCdfOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c70e38faab713e23b5defa890d35bfadeac5940a --- /dev/null +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +class PmfToQuantizedCdfOpTest : public OpsTestBase { + protected: + void SetupOp(int precision, Tensor* input) { + TF_ASSERT_OK(NodeDefBuilder("pmf_to_cdf", "PmfToQuantizedCdf") + .Input(FakeInput(DT_FLOAT)) + .Attr("precision", precision) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + inputs_.clear(); + inputs_.emplace_back(input); + } + + void GenerateData(random::SimplePhilox* rand, + gtl::MutableArraySlice slice) { + constexpr float minimum = std::numeric_limits::epsilon(); + float sum = 0; + for (float& value : slice) { + value = std::max(rand->RandFloat(), minimum); + sum += value; + } + for (float& value : slice) { + value /= sum; + } + } + + void Verify(int precision, const Tensor& pmf_tensor, + const Tensor& cdf_tensor) { + ASSERT_EQ(pmf_tensor.dims(), cdf_tensor.dims()); + const int n = pmf_tensor.dims(); + + for (int i = 0; i < n - 1; ++i) { + EXPECT_EQ(pmf_tensor.dim_size(i), cdf_tensor.dim_size(i)); + } + + auto pmf = pmf_tensor.flat_inner_dims(); + auto cdf = cdf_tensor.flat_inner_dims(); + EXPECT_EQ(pmf.dimension(1) + 1, cdf.dimension(1)); + + const int normalizer = 1 << precision; + for (int i = 0; i < pmf.dimension(0); ++i) { + EXPECT_EQ(0, cdf(i, 0)); + + TTypes::UnalignedConstVec cdf_slice(&cdf(i, 0), cdf.dimension(1)); + + for (int j = 1; j < cdf_slice.size(); ++j) { + const int32 diff = cdf_slice(j) - cdf_slice(j - 1); + EXPECT_GT(diff, 0); + } + + EXPECT_LE(cdf_slice(cdf_slice.size() - 1), normalizer); + } + } +}; + +TEST_F(PmfToQuantizedCdfOpTest, UnderSum) { + Tensor pmf(DT_FLOAT, {1, 10, 1, 32}); + auto matrix = pmf.flat_inner_dims(); + const std::size_t n = matrix.dimension(1); + + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + for (int64 i = 0; i < matrix.dimension(0); ++i) { + GenerateData(&rand, {&matrix(i, 0), n}); + } + + constexpr int kPrecision = 10; + SetupOp(kPrecision, &pmf); + TF_ASSERT_OK(RunOpKernel()); + + Verify(kPrecision, pmf, *GetOutput(0)); +} + +TEST_F(PmfToQuantizedCdfOpTest, OverSum) { + Tensor pmf(DT_FLOAT, {10, 1, 1, 100}); + auto matrix = pmf.flat_inner_dims(); + + // Half of each PMF is filled with zeros. The op will round up zeros to ones, + // post quantization. These round ups are likely to make the sum over + // normalizer value. + matrix.setZero(); + const std::size_t n = matrix.dimension(1) / 2; + + random::PhiloxRandom gen; + random::SimplePhilox rand(&gen); + for (int64 i = 0; i < matrix.dimension(0); ++i) { + GenerateData(&rand, {&matrix(i, 0), n}); + } + + constexpr int kPrecision = 7; + SetupOp(kPrecision, &pmf); + TF_ASSERT_OK(RunOpKernel()); + + Verify(kPrecision, pmf, *GetOutput(0)); +} + +TEST_F(PmfToQuantizedCdfOpTest, ShapeFn) { + ShapeInferenceTestOp op("PmfToQuantizedCdf"); + + INFER_OK(op, "?", "?"); + INFER_OK(op, "[3]", "[4]"); + INFER_OK(op, "[3,4]", "[d0_0,5]"); + INFER_OK(op, "[3,4,5]", "[d0_0,d0_1,6]"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc index 9056d1a6963d7be92f499db31385fb6afe2dc515..9bb171298f85088fdb776302776f2ba379b4f52e 100644 --- a/tensorflow/contrib/coder/ops/coder_ops.cc +++ b/tensorflow/contrib/coder/ops/coder_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; @@ -115,5 +116,36 @@ decoded: An int32 tensor with shape equal to `shape`. precision: The number of bits for probability quantization. Must be <= 16, and must match the precision used by RangeEncode that produced `encoded`. )doc"); + +REGISTER_OP("PmfToQuantizedCdf") + .Input("pmf: float") + .Output("cdf: int32") + .Attr("precision: int >= 1") + .SetShapeFn([] (InferenceContext* c) { + ShapeHandle in; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in)); + DimensionHandle last; + TF_RETURN_IF_ERROR(c->Add(c->Dim(in, -1), 1, &last)); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->ReplaceDim(in, -1, last, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Converts PMF to quantized CDF. This op uses floating-point operations +internally. Therefore the quantized output may not be consistent across multiple +platforms. For entropy encoders and decoders to have the same quantized CDF on +different platforms, the quantized CDF should be produced once and saved, then +the saved quantized CDF should be used everywhere. + +After quantization, if PMF sums to less than or equal to 2^precision, then this +is equivalent to cumsum over the last dimension. This op makes no effort to make +the sum close to 2^precision when the sum is already <= 2^precision. + +After quantization, if PMF sums to greater than 2^precision, then some values of +PMF is decreased to keep the sum no more than 2^precision. + +Note that the input PMF is pre-quantization. +)doc"); // clang-format on } // namespace tensorflow diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index 388d8e6ed6d9cb9400b0bfbe8e3f50b80149ea1a..bcee0b04c8430588c2dcbc199504bede0436f8f1 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -46,15 +46,3 @@ cuda_py_test( ], xla_enabled = True, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/copy_graph/BUILD b/tensorflow/contrib/copy_graph/BUILD index 8ec706df74e2c91345c4bf7a506fdb424a996773..fa44c4d54e1ee871feb425115525b1cf8b732214 100644 --- a/tensorflow/contrib/copy_graph/BUILD +++ b/tensorflow/contrib/copy_graph/BUILD @@ -41,15 +41,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index b806799202bff4f2f6dbf717fbeea74a04b8cd6e..102bc460fdadb0ad5dc9a2960b8655c55357108e 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -201,7 +201,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it #stores String-based info such as name, device and type of the op. #Unique to every Operation instance. - new_node_def = deepcopy(op._node_def) + new_node_def = deepcopy(op.node_def) #Change the name new_node_def.name = new_name @@ -211,7 +211,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): #Make a copy of the op_def too. #Its unique to every _type_ of Operation. - op_def = deepcopy(op._op_def) + op_def = deepcopy(op.op_def) #Initialize a new Operation instance new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types, diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD index 7aad4abdb908d0284b85137bff842bd0f38d09c6..5c1a17df4f95f3c4d05b286de0e3d7b009a76bd7 100644 --- a/tensorflow/contrib/crf/BUILD +++ b/tensorflow/contrib/crf/BUILD @@ -40,15 +40,3 @@ cuda_py_tests( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 721dc4d0801d1f0e116921888e3851a95e0b72b0..a5e065b93a23c3dd2838d81e7cf537dec226f4f9 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -281,6 +281,21 @@ class CrfTest(test.TestCase): self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), expected_max_sequence[:sequence_lengths]) + def testCrfDecodeZeroSeqLength(self): + """ + Test that crf_decode works when sequence_length contains one or more zeros. + """ + with self.test_session() as sess: + inputs = constant_op.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = constant_op.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = constant_op.constant(np.zeros([2], + dtype=np.int32)) + values = crf.crf_decode(inputs, transition_params, sequence_lengths) + tags, scores = sess.run(values) + self.assertEqual(len(tags.shape), 2) + self.assertEqual(len(scores.shape), 1) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index faa78769b98699af59047aed2865771120110fc2..e37c029cebf30eba59c560bc00ed73d2eea86213 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -105,8 +105,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, return utils.smart_cond( pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], 1), - fn1=_single_seq_fn, - fn2=_multi_seq_fn) + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -479,15 +479,17 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + # sequence length is not allowed to be less than zero + sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, inputs=inputs, - sequence_length=sequence_length - 1, + sequence_length=sequence_length_less_one, initial_state=initial_state, time_major=False, dtype=dtypes.int32) backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] - backpointers, sequence_length - 1, seq_dim=1) + backpointers, sequence_length_less_one, seq_dim=1) # Computes backward decoding. Extract tag indices from backpointers. crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) @@ -497,7 +499,7 @@ def crf_decode(potentials, transition_params, sequence_length): decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] crf_bwd_cell, inputs=backpointers, - sequence_length=sequence_length - 1, + sequence_length=sequence_length_less_one, initial_state=initial_state, time_major=False, dtype=dtypes.int32) @@ -511,7 +513,7 @@ def crf_decode(potentials, transition_params, sequence_length): return decode_tags, best_score return utils.smart_cond( - pred=math_ops.equal( - potentials.shape[1].value or array_ops.shape(potentials)[1], 1), - fn1=_single_seq_fn, - fn2=_multi_seq_fn) + pred=math_ops.equal(potentials.shape[1].value or + array_ops.shape(potentials)[1], 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fec358c4e1067dc8dc8173d1b9d05dc90b90ca05..d68015ae1565b778b1ba0744f515d09007175e93 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -9,52 +9,10 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -tf_custom_op_library( - name = "python/ops/_cudnn_rnn_ops.so", - srcs = [ - "kernels/cudnn_rnn_ops.cc", - "ops/cudnn_rnn_ops.cc", - ], - deps = [ - "//tensorflow/core/kernels:bounds_check_lib", - "@farmhash_archive//:farmhash", - ], -) - -tf_kernel_library( - name = "cudnn_rnn_kernels", - srcs = ["kernels/cudnn_rnn_ops.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", - "//third_party/eigen3", - "@farmhash_archive//:farmhash", - ], -) - -tf_gen_op_libs( - op_lib_names = ["cudnn_rnn_ops"], - deps = [ - "//tensorflow/core:lib", - ], -) - -tf_gen_op_wrapper_py( - name = "cudnn_rnn_ops", - deps = [":cudnn_rnn_ops_op_lib"], -) tf_custom_op_py_library( name = "cudnn_rnn_py", @@ -64,20 +22,14 @@ tf_custom_op_py_library( "python/layers/cudnn_rnn.py", "python/ops/cudnn_rnn_ops.py", ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":cudnn_rnn_ops", + "//tensorflow/contrib/eager/python:checkpointable_utils", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:cudnn_rnn_ops_gen", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", @@ -172,32 +124,3 @@ cuda_py_test( "requires_cudnn5", ], ) - -tf_cc_test( - name = "cudnn_rnn_ops_test_cc", - size = "small", - srcs = [ - "ops/cudnn_rnn_ops_test.cc", - ], - deps = [ - ":cudnn_rnn_ops_op_lib", - "//tensorflow/core", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 9897c31a98e0b335c18a84825fc518ed1fc310a2..6fb56b0858786662546ecab425b1a2564fbd9a64 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import argparse import collections +import functools import itertools import os import sys @@ -34,7 +35,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_nn_ops @@ -53,6 +54,7 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import adagrad from tensorflow.python.training import adam +from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop @@ -265,7 +267,7 @@ def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None): return outputs, (output_state_fw, output_state_bw) -class CudnnRNNTestBasic(TensorFlowTestCase): +class CudnnRNNTestBasic(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @@ -467,7 +469,7 @@ class CudnnRNNTestBasic(TensorFlowTestCase): # TODO(jamesqin): Transform to parameterized test after it is included in the # TF open source codebase. -class CudnnRNNTestSaveRestore(TensorFlowTestCase): +class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase): def _CompareWeights(self, lhs, rhs): self.assertEqual(len(lhs), len(rhs)) @@ -701,9 +703,146 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): self._TestSaveRestoreHelper(CUDNN_RNN_RELU) +class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): + + def _VerifyCheckpoint( + self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn, + num_layers, input_size, expected_variable_values, num_applications=3): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with ops.device("gpu:0"): + cudnn_layer = cudnn_cell_fn() + cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer) + status = cudnn_checkpoint.restore(checkpoint_path) + inputs = 3. * array_ops.ones([num_applications, num_layers, input_size], + dtype=dtypes.float32) + cudnn_output, _ = cudnn_layer(inputs) + status.assert_consumed().run_restore_ops() + second_save_path = cudnn_checkpoint.save(checkpoint_prefix) + restore_layer = compatible_cell_fn() + restore_layer_checkpoint = checkpointable_utils.Checkpoint( + cell=restore_layer) + status = restore_layer_checkpoint.restore(second_save_path) + current_state = restore_layer.zero_state(1, dtypes.float32) + for _ in range(num_applications): + restore_layer_output, current_state = restore_layer( + inputs=3. * array_ops.ones([1, input_size]), + state=current_state) + status.assert_consumed().run_restore_ops() + self.assertTrue(restore_layer.variables) + for variable, expected_value in zip( + restore_layer.variables, expected_variable_values): + self.assertAllClose(expected_value, self.evaluate(variable)) + self.assertAllClose(self.evaluate(restore_layer_output), + self.evaluate(cudnn_output)[-1, -1:, ...]) + + def _CheckpointableSingleCellUnidirectionalTestTemplate( + self, single_cell_fn, cudnn_cell_fn): + # Single-layer cuDNN cells with object-based checkpointing should be + # checkpoint compatible with either single CudnnCompatible cells or + # MultiRnnCells with one cell. + input_size = 3 + save_cell_layer = single_cell_fn() + save_cell_layer( + inputs=array_ops.ones([1, input_size]), + state=save_cell_layer.zero_state(1, dtypes.float32)) + self.assertTrue(save_cell_layer.variables) + expected_values = [] + np.random.seed(10) + for variable in save_cell_layer.variables: + value = np.random.normal(size=variable.shape) + expected_values.append(value) + self.evaluate(variable.assign(value)) + save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + first_save_path = save_checkpoint.save(checkpoint_prefix) + self._VerifyCheckpoint( + checkpoint_path=first_save_path, + compatible_cell_fn= + lambda: rnn_cell_impl.MultiRNNCell([single_cell_fn()]), + cudnn_cell_fn=cudnn_cell_fn, + num_layers=1, + expected_variable_values=expected_values, + input_size=input_size) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + @test_util.run_in_graph_and_eager_modes() + def testLSTMCheckpointableSingleLayer(self): + num_units = 2 + direction = CUDNN_RNN_UNIDIRECTION + self._CheckpointableSingleCellUnidirectionalTestTemplate( + single_cell_fn=functools.partial( + cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), + cudnn_cell_fn=functools.partial( + cudnn_rnn.CudnnLSTM, num_layers=1, num_units=num_units, + direction=direction, name="awesome_lstm")) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + @test_util.run_in_graph_and_eager_modes() + def testGRUCheckpointableSingleLayer(self): + num_units = 2 + direction = CUDNN_RNN_UNIDIRECTION + with self.assertRaises(NotImplementedError): + # TODO(allenl): Implement object-based saving for GRUs and other cells. + self._CheckpointableSingleCellUnidirectionalTestTemplate( + single_cell_fn=functools.partial( + cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units), + cudnn_cell_fn=functools.partial( + cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units, + direction=direction, name="awesome_gru")) + + def _CheckpointableMultiLayerTestTemplate( + self, single_cell_fn, cudnn_cell_fn, num_layers): + + def _MultiCellFn(): + return rnn_cell_impl.MultiRNNCell( + [single_cell_fn() for _ in range(num_layers)]) + input_size = 3 + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session(graph=save_graph): + save_layer = _MultiCellFn() + save_layer(inputs=array_ops.ones([1, input_size]), + state=save_layer.zero_state(1, dtypes.float32)) + self.assertTrue(save_layer.variables) + expected_values = [] + np.random.seed(10) + for variable in save_layer.variables: + value = np.random.normal(size=variable.shape) + expected_values.append(value) + self.evaluate(variable.assign(value)) + save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + first_save_path = save_checkpoint.save(checkpoint_prefix) + self._VerifyCheckpoint( + checkpoint_path=first_save_path, + compatible_cell_fn=_MultiCellFn, cudnn_cell_fn=cudnn_cell_fn, + num_layers=num_layers, + expected_variable_values=expected_values, + input_size=input_size) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + @test_util.run_in_graph_and_eager_modes() + def testCudnnCompatibleLSTMCheckpointablMultiLayer(self): + num_units = 2 + num_layers = 3 + direction = CUDNN_RNN_UNIDIRECTION + self._CheckpointableMultiLayerTestTemplate( + single_cell_fn=functools.partial( + cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units), + cudnn_cell_fn=functools.partial( + cudnn_rnn.CudnnLSTM, num_layers=num_layers, num_units=num_units, + direction=direction, name="awesome_lstm"), + num_layers=num_layers) + + # TODO(jamesqin): Transform to parameterized test after it is included in the # TF open source codebase. -class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): +class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") @@ -884,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): rtol=2e-5) -class CudnnRNNTestParamsSize(TensorFlowTestCase): +class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase): def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, dtype, direction): @@ -931,7 +1070,7 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): dtype, direction) -class CudnnRNNTestTraining(TensorFlowTestCase): +class CudnnRNNTestTraining(test_util.TensorFlowTestCase): def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): """Compute the numeric gradient of y wrt to x. diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 36fba917a8f56c26fd5b4c3468d1d980a8ba2ba5..d58198faf353aab68430d2fa153a18de359112de 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -142,6 +142,9 @@ class _CudnnRNN(base_layer.Layer): """ # pylint:enable=line-too-long + # TODO(allenl): Document object-based saving and checkpoint compatibility once + # it's implemented for more cuDNN Layers. + # The following are constants defined by subclasses. # Type of RNN cell. _rnn_mode = None @@ -355,7 +358,8 @@ class _CudnnRNN(base_layer.Layer): "CUDA/CuDNN generations.") # Initialize opaque params with a tensor. self.kernel = vs.get_variable( - "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + "opaque_kernel", dtype=self._plain_dtype, + initializer=opaque_params_t, validate_shape=False) # Create saveable in the outer scope of the cudnn subgraph, such that # alternative subgraph with platform-independent rnn cells can load the # checkpoints directly. @@ -363,6 +367,11 @@ class _CudnnRNN(base_layer.Layer): self._create_saveable() self.built = True + def _gather_saveables_for_checkpoint(self): + raise NotImplementedError( + "This cell does not yet support object-based saving. File a feature " + "request if this limitation bothers you.") + def call(self, inputs, initial_state=None, training=True): """Runs the forward step for the RNN model. @@ -499,6 +508,8 @@ class _CudnnRNN(base_layer.Layer): direction=self.direction, scope=vs.get_variable_scope(), name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) + self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access + checkpointable=self, dtype=self._plain_dtype) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) @@ -521,6 +532,16 @@ class CudnnLSTM(_CudnnRNN): return ([self.num_layers * self.num_dirs, batch_size, self.num_units], [self.num_layers * self.num_dirs, batch_size, self.num_units]) + @property + def _gather_saveables_for_checkpoint(self): + if self._direction == CUDNN_RNN_UNIDIRECTION: + # Skip one inheritance level to avoid NotImplementedError. + return super(_CudnnRNN, self)._gather_saveables_for_checkpoint + else: + raise NotImplementedError( + "Object-based saving does not currently support bidirectional LSTM " + "cells. File a feature request if this limitation bothers you.") + class _CudnnRNNNoInputC(_CudnnRNN): """Abstract simple CudnnRNN layer without input_c.""" diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index e87162f0ee9cc4eed795555171f55a93639e83cf..b615824460be59697d2207f5d5af0eba748c5237 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -17,27 +17,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops +from tensorflow.contrib.eager.python import checkpointable_utils from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.layers import base as base_layer +from tensorflow.python.keras._impl.keras.engine import base_layer from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import resource_loader +from tensorflow.python.training import checkpointable as checkpointable_lib from tensorflow.python.training import saver -_cudnn_rnn_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile("_cudnn_rnn_ops.so")) - CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" CUDNN_LSTM = "lstm" @@ -91,19 +88,23 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python - r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate - h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 + # reset gate + $$r_t = \sigma(x_t * W_r + h_t-1 * R_h + b_{Wr} + b_{Rr})$$ + # update gate + $$u_t = \sigma(x_t * W_u + h_t-1 * R_u + b_{Wu} + b_{Ru})$$ + # new memory gate + $$h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_{Rh}) + b_{Wh})$$ + $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$ ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): ```python - h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_Wh) # new memory gate + # new memory gate + \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\) ``` which is not equivalent to Cudnn GRU: in addition to the extra bias term b_Rh, ```python - r .* (h * R) != (r .* h) * R + \\(r .* (h * R) != (r .* h) * R\\) ``` """ @@ -267,13 +268,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): # instead of having the master pull all slices and then save them. slice_spec = "" params = weights + biases - param_names = weight_names + bias_names + self._weight_names = weight_names + self._bias_names = bias_names + self._param_names = weight_names + bias_names + prefixed_param_names = weight_names + bias_names if self._scope: - param_names = ["%s/%s" % (self._scope, pn) for pn in param_names] - + prefixed_param_names = [ + "%s/%s" % (self._scope, pn) for pn in prefixed_param_names] specs = [ saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) - for param, param_name in zip(params, param_names) + for param, param_name in zip(params, prefixed_param_names) ] super(CudnnOpaqueParamsSaveable, self).__init__( array_ops.identity(self._variables), specs, name) @@ -286,6 +290,45 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): return state_ops.assign( self._variables, opaque_params, validate_shape=False) + def _checkpointable_save(self, save_buffer): + weights, biases = self._OpaqueParamsToCanonical() + with ops.device("gpu:0"): + (weights, _), (biases, _) = self._TransformCanonical( + weights, biases) + for name, tensor in zip(self._param_names, weights + biases): + save_buffer[name] = array_ops.identity(tensor) + + def _checkpointable_restore(self, restore_buffer): + tensors = [array_ops.identity(restore_buffer[name]) + for name in self._param_names] + return self.restore( + restored_tensors=tensors, + restored_shapes=None # Unused + ) + + def _add_checkpointable_dependencies(self, checkpointable, dtype): + """Add canonical weight dependencies to `checkpointable`. + + When saving or restoring, converts to or from the opaque buffer + format. Weights are saved and loaded in the configuration expected by + cuDNN-compatible cells. + + Args: + checkpointable: An object inheriting from `CheckpointableBase` to add + dependencies too (typically the cuDNN `Layer`). + dtype: The dtype for the canonical parameter Tensors. + """ + split_dependencies = checkpointable_utils.split_dependency( + component_names=self._param_names, + component_dtypes=(dtype,) * len(self._param_names), + fill_save_buffer_fn=self._checkpointable_save, + consume_restore_buffer_fn=self._checkpointable_restore) + self._checkpointable_track_params(checkpointable, split_dependencies) + + def _checkpointable_track_params(self, checkpointable, params): + """Tracks parameters in a canonical configuration.""" + return # NotImplementedError raised by the Layer. + def _TFCanonicalNamePrefix(self, layer, is_fwd=True): if self._direction == CUDNN_RNN_UNIDIRECTION: return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name) @@ -481,10 +524,7 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): _rnn_mode = CUDNN_LSTM _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER - # pylint:disable=protected-access - _rnn_cell_name = base_layer._to_snake_case(CudnnCompatibleLSTMCell.__name__) - - # pylint:enable=protected-access + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleLSTMCell.__name__) def _cudnn_to_tf_gate_params(self, *cu_gate_order): i_g, f_g, c_g, o_g = cu_gate_order @@ -575,6 +615,29 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable): tf_biases.append(b) tf_bias_names.append(prefix + "/bias") + def _checkpointable_track_params(self, checkpointable, params): + """Track parameters for compatibility with CudnnCompatibleLSTMCell.""" + biases = [] + weights = [] + for name in self._weight_names: + weights.append(params[name]) + for name in self._bias_names: + biases.append(params[name]) + assert len(params) == len(weights) + len(biases) + if len(weights) == 1 and len(biases) == 1: + # For single-layer cells, allow substituting a cell with no MultiRNNCell + # wrapping. + kernel, = weights # pylint: disable=unbalanced-tuple-unpacking + bias, = biases # pylint: disable=unbalanced-tuple-unpacking + checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access + checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access + assert len(biases) == len(weights) + for cell_index, (bias, kernel) in enumerate(zip(biases, weights)): + cell = checkpointable_lib.Checkpointable() + checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access + cell.bias = bias + cell.kernel = kernel + class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): """SaveableObject implementation handling Cudnn GRU opaque params.""" @@ -582,10 +645,7 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): _rnn_mode = CUDNN_GRU _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER - # pylint:disable=protected-access - _rnn_cell_name = base_layer._to_snake_case(CudnnCompatibleGRUCell.__name__) - - # pylint:enable=protected-access + _rnn_cell_name = base_layer.to_snake_case(CudnnCompatibleGRUCell.__name__) def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -664,11 +724,7 @@ class CudnnGRUSaveable(CudnnOpaqueParamsSaveable): class CudnnRNNSimpleSaveable(CudnnLSTMSaveable): """SaveableObject implementation handling Cudnn RNN Tanh opaque params.""" - # pylint:disable=protected-access - _rnn_cell_name = base_layer._to_snake_case( - rnn_cell_impl.BasicRNNCell.__name__) - - # pylint:enable=protected-access + _rnn_cell_name = base_layer.to_snake_case(rnn_cell_impl.BasicRNNCell.__name__) def _cudnn_to_tf_weights(self, *cu_weights): r"""Stitching cudnn canonical weights to generate tf canonical weights.""" @@ -1584,31 +1640,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER -@ops.RegisterGradient("CudnnRNN") -def _cudnn_rnn_backward(op, *grad): - if not op.get_attr("is_training"): - raise ValueError( - "CudnnRNN must set is_training to True to be used in gradients") - return gen_cudnn_rnn_ops.cudnn_rnn_backprop( - input=op.inputs[0], - input_h=op.inputs[1], - input_c=op.inputs[2], - params=op.inputs[3], - output=op.outputs[0], - output_h=op.outputs[1], - output_c=op.outputs[2], - output_backprop=grad[0], - output_h_backprop=grad[1], - output_c_backprop=grad[2], - reserve_space=op.outputs[3], - dropout=op.get_attr("dropout"), - seed=op.get_attr("seed"), - seed2=op.get_attr("seed2"), - rnn_mode=op.get_attr("rnn_mode"), - input_mode=op.get_attr("input_mode"), - direction=op.get_attr("direction")) - - ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 0458199ff771bc45603106411550a39448e515b8..8bdbba83ef6a8541158d956e36caf6a9be435c5b 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -8,6 +8,11 @@ load( "//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_libs", + "if_not_windows", +) +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "if_static", ) py_library( @@ -17,35 +22,25 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/contrib/data/python/ops:shuffle_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", ], ) +cc_library( + name = "lib_proto_parsing_for_dataset_ops", + deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]), +) + tf_custom_op_library( name = "_dataset_ops.so", srcs = ["ops/dataset_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"], + deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] + + if_static( + extra_deps = [":lib_proto_parsing_for_dataset_ops"], + otherwise = [], + ), ) tf_gen_op_libs( op_lib_names = ["dataset_ops"], ) - -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index fcdccdd26ca1824bf13f1fd0cfd80b20ca8a10c3..077cbba9d2ae41a83f6c358a63ae27aec5741e2c 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,20 +23,28 @@ removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter +@@SqlDataset +@@assert_element_shape @@batch_and_drop_remainder +@@bucket_by_sequence_length @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ignore_errors +@@make_batched_features_dataset +@@make_csv_dataset @@make_saveable_from_iterator @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave +@@prefetch_to_device @@read_batch_features @@rejection_resample +@@sample_from_datasets @@scan @@shuffle_and_repeat +@@sliding_window_batch @@sloppy_interleave @@unbatch @@ -49,6 +57,7 @@ from __future__ import print_function # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops.batching import assert_element_shape from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import map_and_batch @@ -58,16 +67,25 @@ from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element +from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave +from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset +from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat +from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) + +# A constant that can be used to enable auto-tuning. +AUTOTUNE = -1 diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 56471911c5c0d1c1825955c67997b5bbc0786463..c56910c7833d4c54fa8db27cd061b404013f3f54 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -9,6 +9,18 @@ exports_files(["LICENSE"]) cc_library( name = "prefetching_kernels", srcs = ["prefetching_kernels.cc"], + deps = [ + "//tensorflow/core:core_cpu_headers_lib", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "directed_interleave_dataset_op", + srcs = ["directed_interleave_dataset_op.cc"], deps = [ "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -28,24 +40,36 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "threadpool_dataset_op", + srcs = ["threadpool_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + +cc_library( + name = "unique_dataset_op", + srcs = ["unique_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "dataset_kernels", deps = [ + ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":prefetching_kernels", + ":threadpool_dataset_op", + ":unique_dataset_op", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..48d3734162525ffc6ace076e4f0523c1d0cae511 --- /dev/null +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -0,0 +1,274 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class DirectedInterleaveDatasetOp : public DatasetOpKernel { + public: + explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + DatasetBase* selector_input; + OP_REQUIRES_OK(ctx, + GetDatasetFromVariantTensor(ctx->input(0), &selector_input)); + + OP_REQUIRES( + ctx, + selector_input->output_dtypes().size() == 1 && + selector_input->output_dtypes()[0] == DT_INT64 && + selector_input->output_shapes().size() == 1 && + selector_input->output_shapes()[0].IsCompatibleWith( + PartialTensorShape({})), + errors::InvalidArgument( + "The selector input must be a dataset of scalar int64 elements.")); + + std::vector data_inputs; + for (size_t i = 1; i < ctx->num_inputs(); ++i) { + DatasetBase* input; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); + data_inputs.push_back(input); + + OP_REQUIRES( + ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(), + errors::InvalidArgument( + "All inputs must have the same output_dtypes. First input " + "has types ", + DataTypeVectorString(data_inputs[0]->output_dtypes()), + ", and input ", i - 1, " has types ", + DataTypeVectorString(input->output_dtypes()))); + } + *output = new Dataset(ctx, selector_input, std::move(data_inputs)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, + std::vector data_inputs) + : GraphDatasetBase(ctx), + selector_input_(selector_input), + data_inputs_(std::move(data_inputs)) { + selector_input_->Ref(); + + output_shapes_ = data_inputs_[0]->output_shapes(); + data_inputs_[0]->Ref(); + for (size_t i = 1; i < data_inputs_.size(); ++i) { + const DatasetBase* data_input = data_inputs_[i]; + data_input->Ref(); + for (size_t j = 0; j < output_shapes_.size(); ++j) { + output_shapes_[j] = MostSpecificCompatibleShape( + output_shapes_[j], data_input->output_shapes()[j]); + } + } + } + + ~Dataset() override { + selector_input_->Unref(); + for (DatasetBase* data_input : data_inputs_) { + data_input->Unref(); + } + } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::DirectedInterleave")})); + } + + const DataTypeVector& output_dtypes() const override { + return data_inputs_[0]->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { + return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + Node* selector_input_node; + TF_RETURN_IF_ERROR( + b->AddParentDataset(ctx, selector_input_, &selector_input_node)); + std::vector data_input_nodes(data_inputs_.size()); + for (size_t i = 0; i < data_inputs_.size(); ++i) { + TF_RETURN_IF_ERROR( + b->AddParentDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + } + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, + {{1, data_input_nodes}}, {}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + selector_input_impl_(params.dataset->selector_input_->MakeIterator( + params.prefix + ".selector")), + num_active_inputs_(params.dataset->data_inputs_.size()) { + data_input_impls_.reserve(params.dataset->data_inputs_.size()); + for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) { + const DatasetBase* data_input = params.dataset->data_inputs_[i]; + data_input_impls_.push_back(data_input->MakeIterator( + strings::StrCat(params.prefix, "[", i, "]"))); + } + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!selector_input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + + while (true) { + std::vector selector_result; + *end_of_sequence = false; + TF_RETURN_IF_ERROR(selector_input_impl_->GetNext( + ctx, &selector_result, end_of_sequence)); + if (*end_of_sequence) { + selector_input_impl_.reset(); + for (auto& data_input_impl : data_input_impls_) { + data_input_impl.reset(); + } + return Status::OK(); + } + + int64 selected_input = selector_result[0].scalar()(); + if (selected_input < 0 || selected_input > data_input_impls_.size()) { + return errors::InvalidArgument( + "Selector index out of range: ", selected_input, + " >= ", data_input_impls_.size()); + } + + if (data_input_impls_[selected_input]) { + bool end_of_selected_input = false; + TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext( + ctx, out_tensors, &end_of_selected_input)); + + if (!end_of_selected_input) { + return Status::OK(); + } + + data_input_impls_[selected_input].reset(); + --num_active_inputs_; + + if (num_active_inputs_ == 0) { + selector_input_impl_.reset(); + *end_of_sequence = true; + return Status::OK(); + } + } + + LOG(WARNING) << "DirectedInterleave selected an exhausted input: " + << selected_input; + } + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (selector_input_impl_) { + TF_RETURN_IF_ERROR(SaveParent(writer, selector_input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("selector_input_impl_empty"), "")); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const auto& data_input_impl = data_input_impls_[i]; + if (data_input_impl) { + TF_RETURN_IF_ERROR(SaveParent(writer, data_input_impl)); + } else { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")), + "")); + } + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("selector_input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, selector_input_impl_)); + } else { + selector_input_impl_.reset(); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + if (!reader->Contains(full_name( + strings::StrCat("data_input_impl_empty[", i, "]")))) { + TF_RETURN_IF_ERROR( + RestoreParent(ctx, reader, data_input_impls_[i])); + } else { + data_input_impls_[i].reset(); + } + } + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr selector_input_impl_ GUARDED_BY(mu_); + std::vector> data_input_impls_ + GUARDED_BY(mu_); + int64 num_active_inputs_ GUARDED_BY(mu_); + }; + + static PartialTensorShape MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2) { + PartialTensorShape output_tensorshape; + if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + return output_tensorshape; + auto dims1 = ts1.dim_sizes(); + auto dims2 = ts2.dim_sizes(); + for (int d = 0; d < ts1.dims(); d++) { + if (dims1[d] == dims2[d]) + output_tensorshape.Concatenate(dims1[d]); + else + output_tensorshape.Concatenate(-1); + } + return output_tensorshape; + } + + const DatasetBase* const selector_input_; + const std::vector data_inputs_; + std::vector output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU), + DirectedInterleaveDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index d3df14bdd03476e9ee4015b374512e5bb9893a63..a2bfce03620a1482f5b21cbf23c66833bc5cd480 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_op_kernel.h" @@ -35,38 +36,25 @@ using FunctionBufferCallback = std::function; class FunctionBufferingResource : public ResourceBase { public: FunctionBufferingResource(FunctionLibraryRuntime* lib, + std::unique_ptr pflr, const NameAttrList& func, int64 buffer_size, const string& source_device, const string& target_device, - const std::vector& func_args, - int64 thread_pool_size) + const std::vector& func_args) : lib_(lib), + pflr_(std::move(pflr)), func_(func), buffer_size_(buffer_size), source_device_(source_device), target_device_(target_device), func_args_(func_args), - thread_pool_(new thread::ThreadPool(Env::Default(), ThreadOptions(), - "buffer_resource", thread_pool_size, - false /* low_latency_hint */)), handle_(kInvalidHandle), is_buffering_(false), end_of_sequence_(false), - cancelled_(false) { - runner_ = [this](std::function c) { - thread_pool_->Schedule(std::move(c)); - }; - } + cancelled_(false) {} ~FunctionBufferingResource() override { Cancel(); - { - mutex_lock l(mu_); - while (is_buffering_) { - cond_var_.wait(l); - } - } - delete thread_pool_; } string DebugString() override { @@ -100,6 +88,20 @@ class FunctionBufferingResource : public ResourceBase { void Cancel() LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); cancelled_ = true; + while (is_buffering_) { + cond_var_.wait(l); + } + } + + // Cancels all pending operations and then clears out the state. + void Reset() LOCKS_EXCLUDED(mu_) { + Cancel(); + mutex_lock l(mu_); + buffer_.clear(); + requests_.clear(); + is_buffering_ = false; + end_of_sequence_ = false; + cancelled_ = false; } // If the buffer has anything, runs `callback` on the first element in the @@ -164,15 +166,12 @@ class FunctionBufferingResource : public ResourceBase { for (int i = 0; i < cancellation_callbacks.size(); ++i) { cancellation_callbacks[i](cancellation_buffer_elements[i]); } - // We only wait on cond_var_ in the destructor, so there would atmost be - // one waiter to notify. - cond_var_.notify_one(); + cond_var_.notify_all(); return; } FunctionLibraryRuntime::Options opts; // Copied from CapturedFunction::generate_step_id(); opts.step_id = -std::abs(static_cast(random::New64())); - opts.runner = &runner_; opts.source_device = source_device_; AllocatorAttributes arg_alloc_attr; arg_alloc_attr.set_on_host(true); @@ -191,13 +190,12 @@ class FunctionBufferingResource : public ResourceBase { mutex_lock l(mu_); BufferElement buffer_element; buffer_element.status = status; - if (!status.ok()) { + if (status.ok()) { + buffer_element.value.swap(*rets); + } else { end_of_sequence_ = true; is_buffering_ = false; - buffer_.push_back(std::move(buffer_element)); - return; } - buffer_element.value.swap(*rets); buffer_.push_back(std::move(buffer_element)); if (!requests_.empty()) { buffer_front = std::move(buffer_.front()); @@ -205,9 +203,16 @@ class FunctionBufferingResource : public ResourceBase { callback = std::move(requests_.front()); requests_.pop_front(); } - if (buffer_.size() < buffer_size_) { + if (buffer_.size() < buffer_size_ && !end_of_sequence_) { restart_buffering = true; } else { + // When the buffer is full, we don't want to call + // FillBuffer() unless we're in cancellation phase in which + // case FillBuffer() will do the final cleanup post + // cancellation. + if (cancelled_) { + restart_buffering = true; + } is_buffering_ = false; } } @@ -222,16 +227,15 @@ class FunctionBufferingResource : public ResourceBase { mutex mu_; FunctionLibraryRuntime* lib_; + std::unique_ptr pflr_; NameAttrList func_; const int64 buffer_size_; const string source_device_; const string target_device_; const std::vector func_args_; - thread::ThreadPool* thread_pool_; FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); std::deque buffer_ GUARDED_BY(mu_); std::deque requests_ GUARDED_BY(mu_); - std::function)> runner_ = nullptr; bool is_buffering_ GUARDED_BY(mu_); bool end_of_sequence_ GUARDED_BY(mu_); bool cancelled_ GUARDED_BY(mu_); @@ -241,12 +245,22 @@ class FunctionBufferingResource : public ResourceBase { class FunctionBufferResourceHandleOp : public OpKernel { public: explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) - : OpKernel(ctx) { + : OpKernel(ctx), flib_def_(nullptr) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_)); + } + + ~FunctionBufferResourceHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } } void Compute(OpKernelContext* ctx) override { @@ -267,33 +281,44 @@ class FunctionBufferResourceHandleOp : public OpKernel { const string& source_device = ctx->device()->name(); - ContainerInfo cinfo; - OP_REQUIRES_OK(ctx, cinfo.Init(ctx->resource_manager(), def())); - // Create the resource. - FunctionBufferingResource* buffer; - OP_REQUIRES_OK( - ctx, ctx->resource_manager()->LookupOrCreate( - cinfo.container(), cinfo.name(), &buffer, - [lib, &source_device, &target_device, func_args, - this](FunctionBufferingResource** ptr) { - *ptr = new FunctionBufferingResource( - lib, func_, buffer_size_, source_device, target_device, - func_args, thread_pool_size_); - return Status::OK(); - })); - OP_REQUIRES_OK(ctx, buffer->Instantiate()); + mutex_lock l(mu_); + if (!initialized_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); + FunctionLibraryRuntime* clone_lib; + std::unique_ptr pflr; + OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib)); + // Create the resource. + FunctionBufferingResource* buffer; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &buffer, + [clone_lib, &pflr, &source_device, &target_device, func_args, + this](FunctionBufferingResource** ptr) { + *ptr = new FunctionBufferingResource( + clone_lib, std::move(pflr), func_, buffer_size_, + source_device, target_device, func_args); + return Status::OK(); + })); + core::ScopedUnref s(buffer); + OP_REQUIRES_OK(ctx, buffer->Instantiate()); + initialized_ = true; + } OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( - ctx, 0, cinfo.container(), cinfo.name(), + ctx, 0, cinfo_.container(), cinfo_.name(), MakeTypeIndex())); } private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + std::unique_ptr flib_def_; NameAttrList func_; int64 buffer_size_; string container_; string name_; - int64 thread_pool_size_; }; REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") @@ -334,25 +359,27 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC( ctx, LookupResource(ctx, handle, &buffer), done); - core::ScopedUnref s(buffer); if (buffer->Finished()) { + buffer->Unref(); ctx->SetStatus(errors::OutOfRange("end_of_sequence")); done(); return; } FunctionBufferCallback callback = - [ctx, done](const BufferElement& buffer_element) { + [ctx, buffer, done](const BufferElement& buffer_element) { Status s = buffer_element.status; if (!s.ok()) { ctx->SetStatus(s); + buffer->Unref(); done(); return; } for (size_t i = 0; i < buffer_element.value.size(); ++i) { ctx->set_output(i, buffer_element.value[i]); } + buffer->Unref(); done(); }; buffer->MaybeGet(std::move(callback)); @@ -374,4 +401,62 @@ REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") FunctionBufferingResourceGetNextOp); #endif // TENSORFLOW_USE_SYCL +// Resets the FunctionBufferingResource, cancelling all pending requests and +// clearing out the buffer. +class FunctionBufferingResourceResetOp : public OpKernel { + public: + explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + ~FunctionBufferingResourceResetOp() override {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle; + OP_REQUIRES_OK(ctx, + HandleFromInput(ctx, "function_buffer_resource", &handle)); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, handle, &buffer)); + core::ScopedUnref s(buffer); + + buffer->Reset(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#endif // TENSORFLOW_USE_SYCL + +class IteratorGetDeviceOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + // NOTE(mrry): We do not currently Validate that the handle + // corresponds to a real IteratorResource, because that symbol is + // not exposed from the framework library. + Tensor* device_name_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &device_name_t)); + // NOTE(mrry): Since the operation's input is a resource, we must be + // colocated with it, and so we can simply return the current device's + // name without looking at the input. + device_name_t->scalar()() = ctx->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), + IteratorGetDeviceOp); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..63e19ae3f837c9d3cfb1221df64360ee74117f13 --- /dev/null +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -0,0 +1,193 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace { + +class ThreadPoolResource : public ResourceBase { + public: + ThreadPoolResource(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads, bool low_latency_hint) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { + } + + // Schedules fn() for execution in the pool of threads. + void Schedule(std::function fn) { + thread_pool_.Schedule(std::move(fn)); + } + + string DebugString() override { return "ThreadPoolResource"; } + + private: + thread::ThreadPool thread_pool_; +}; + +// Creates a handle to a ThreadPool resource. Note that we don't use +// ResourceOpKernel here because the ThreadPoolResource constructor requires +// access to `OpKernelContext::env()`, which isn't provided by +// `ResourceOpKernel::CreateResource()`. +class ThreadPoolHandleOp : public OpKernel { + public: + explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES( + ctx, num_threads_ > 0, + errors::InvalidArgument("`num_threads` must be greater than zero.")); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~ThreadPoolHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + ThreadPoolResource* resource; + OP_REQUIRES_OK(ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](ThreadPoolResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new ThreadPoolResource( + ctx->env(), {}, display_name_, + num_threads_, + false /* low_latency_hint */); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + string display_name_; + int num_threads_; +}; + +class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + ThreadPoolResource* threadpool_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &threadpool_resource)); + core::ScopedUnref unref_iterator(threadpool_resource); + + *output = new Dataset(ctx, input, threadpool_resource); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + ThreadPoolResource* threadpool) + : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) { + input_->Ref(); + threadpool_->Ref(); + } + + ~Dataset() override { + input_->Unref(); + threadpool_->Unref(); + } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented( + "Cannot currently serialize the thread pool for a " + "ThreadPoolDataset."); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + ThreadPoolResource* pool = dataset()->threadpool_; + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = [pool](std::function c) { + pool->Schedule(std::move(c)); + }; + params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + params.lib = ctx->lib(); + params.function_library = ctx->function_library(); + params.allocator_getter = ctx->allocator_getter(); + IteratorContext threadpool_ctx(params); + return input_impl_->GetNext(&threadpool_ctx, out_tensors, + end_of_sequence); + } + + private: + std::unique_ptr input_impl_; + }; + + const DatasetBase* const input_; + ThreadPoolResource* const threadpool_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU), + ThreadPoolHandleOp); +REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU), + ThreadPoolDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc similarity index 99% rename from tensorflow/core/kernels/data/unique_dataset_op.cc rename to tensorflow/contrib/data/kernels/unique_dataset_op.cc index 7726ee0edf71b34cb65fe5fceb2b60dd30bb58e2..69fbb0fcdcce87951d2c9b84210fda378081b103 100644 --- a/tensorflow/core/kernels/data/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 289ffa1d9c29092cdf434e86ed5553ff9644d43e..137deb63527f0bdde7da8d5be83ed038f430e581 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -17,6 +17,23 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("DirectedInterleaveDataset") + .Input("selector_input_dataset: variant") + .Input("data_input_datasets: N * variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("N: int >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +A substitute for `InterleaveDataset` on a fixed list of `N` datasets. + +selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines + which of the `N` data inputs should produce the next output element. +data_input_datasets: `N` datasets with the same type that will be interleaved + according to the values of `selector_input_dataset`. +)doc"); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") @@ -27,6 +44,24 @@ REGISTER_OP("IgnoreErrorsDataset") Creates a dataset that contains the elements of `input_dataset` ignoring errors. )doc"); +REGISTER_OP("UniqueDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the unique elements of `input_dataset`. +)doc"); + +REGISTER_OP("IteratorGetDevice") + .Input("resource: resource") + .Output("device: string") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Returns the name of the device on which `resource` has been placed. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") @@ -35,7 +70,6 @@ REGISTER_OP("FunctionBufferingResource") .Attr("container: string") .Attr("f: func") .Attr("buffer_size: int") - .Attr("thread_pool_size: int") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Creates a resource that fills up a buffer by making function calls. @@ -45,7 +79,6 @@ target_device: Target device to execute the function on. resource: Handle to the resource created. f: Function to be executed. buffer_size: Size of the buffer. -thread_pool_size: Size of the threadpool doing the prefetching. container: If non-empty, this resource is placed in the given container. Otherwise, a default container is used. shared_name: If non-empty, this resource will be shared under the given name @@ -65,4 +98,42 @@ output: A list of return values. output_types: The type list for the return values. )doc"); +REGISTER_OP("FunctionBufferingResourceReset") + .Input("function_buffer_resource: resource") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Resets the FunctionBufferingResource. + +function_buffer_resource: The FunctionBufferingResource handle. +)doc"); + +REGISTER_OP("ThreadPoolDataset") + .Input("input_dataset: variant") + .Input("thread_pool: resource") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that uses a custom thread pool to compute `input_dataset`. + +handle: A resource produced by the ThreadPoolHandle op. +)doc"); + +REGISTER_OP("ThreadPoolHandle") + .Output("handle: resource") + .SetShapeFn(shape_inference::ScalarShape) + .Attr("num_threads: int") + .Attr("display_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Doc(R"doc( +Creates a custom thread pool with the given number of threads. + +handle: A resource that can be consumed by one or more ThreadPoolDataset ops. +num_threads: The number of threads in the thread pool. +display_name: A human-readable name for the threads that may be visible in + some visualizations. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index e51d57cc896dc32d8e11912cd89f34a04a858c78..b15b9663f4c1bd48ed2a482f657490b0b5677673 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,24 +4,24 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", - size = "small", + size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", @@ -37,8 +37,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -59,10 +58,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -79,8 +78,7 @@ py_test( ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -127,13 +125,13 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -145,7 +143,7 @@ tf_py_test( additional_deps = [ ":dataset_serialization_test", "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -168,13 +166,14 @@ py_test( srcs = ["interleave_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ + "manual", "no_oss", "no_pip", + "notap", ], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -185,6 +184,7 @@ py_test( "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -195,7 +195,8 @@ tf_py_test( srcs = ["get_single_element_test.py"], additional_deps = [ "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:get_single_element", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -213,8 +214,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -259,8 +259,8 @@ py_test( srcs_version = "PY2AND3", deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:counter", + "//tensorflow/contrib/data/python/ops:enumerate_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -272,6 +272,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -295,6 +296,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", ], ) @@ -306,12 +308,12 @@ py_test( srcs_version = "PY2AND3", tags = ["noasan"], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:resampling", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:string_ops", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -324,7 +326,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -343,11 +345,11 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -375,7 +377,6 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -412,10 +413,25 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "threadpool_dataset_ops_test", + size = "small", + srcs = ["threadpool_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:threadpool", + "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -427,13 +443,13 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/contrib/stateless", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -446,25 +462,20 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) -py_test( +cuda_py_test( name = "prefetching_ops_test", size = "small", srcs = ["prefetching_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "manual", - "no_oss", # b/68785503 - ], - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -479,16 +490,19 @@ py_test( ], ) -filegroup( - name = "all_files", - srcs = glob( - include = [ - "**/*", - ], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +tf_py_test( + name = "slide_dataset_op_test", + size = "small", + srcs = ["slide_dataset_op_test.py"], + additional_deps = [ + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sliding", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 71dc1c1172c9d515d4c85f85257c952135098329..413d8737978b695ac443c92036d6641e5c73f28c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -28,8 +28,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test @@ -311,10 +313,10 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> - # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], np.array(37.0) * np.arange(7)) @@ -381,11 +383,51 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testBatchAndMapDataset(self): - return self._testBatchAndMapDatasetHelper() + def testMapAndBatchDataset(self): + return self._testMapAndBatchDatasetHelper() - def testBatchAndMapDatasetWithParallelBatching(self): - return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchDatasetWithParallelBatching(self): + return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + iterator = ( + dataset_ops.Dataset.range(10).apply( + batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), + batch_size=4, + drop_remainder=drop_remainder)).make_one_shot_iterator()) + if drop_remainder: + self.assertEqual([4, 1], iterator.output_shapes.as_list()) + else: + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + if not drop_remainder: + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMapAndBatchPartialBatch(self): + return self._testMapAndBatchPartialBatchHelper() + + def testMapAndBatchPartialBatchDropRemainder(self): + return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) + + def testMapAndBatchYieldsPartialBatch(self): + iterator = (dataset_ops.Dataset.range(10) + .apply(batching.map_and_batch( + lambda x: array_ops.reshape(x * x, [1]), 4)) + .make_one_shot_iterator()) + self.assertEqual([None, 1], iterator.output_shapes.as_list()) + next_element = iterator.get_next() + with self.test_session() as sess: + self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) + self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) + self.assertAllEqual([[64], [81]], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) def testMapAndBatchSparse(self): @@ -411,7 +453,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testBatchAndMapDatasetFails(self): + def testMapAndBatchDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -425,7 +467,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testBatchAndMapDatasetShapeMismatch(self): + def testMapAndBatchDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -539,5 +581,73 @@ class PaddedBatchDatasetSerializationTest( lambda: build_dataset(seq_lens2), 8) +class RestructuredDatasetTest(test.TestCase): + + def test_assert_element_shape(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + expected_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 4))) + result = dataset.apply(batching.assert_element_shape(expected_shapes)) + self.assertEqual(expected_shapes, result.output_shapes) + + iterator = result.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def test_assert_wrong_element_shape(self): + + def create_dataset(_): + return (array_ops.ones(2, dtype=dtypes.float32), + array_ops.zeros((3, 4), dtype=dtypes.int32)) + + dataset = dataset_ops.Dataset.range(3).map(create_dataset) + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + with self.assertRaises(ValueError): + dataset.apply(batching.assert_element_shape(wrong_shapes)) + + def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): + + def create_unknown_shape_dataset(x): + return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) + + dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) + unknown_shapes = (tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)) + self.assertEqual(unknown_shapes, dataset.output_shapes) + + wrong_shapes = (tensor_shape.TensorShape(2), + tensor_shape.TensorShape((3, 10))) + iterator = ( + dataset.apply(batching.assert_element_shape(wrong_shapes)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(get_next) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index f1b494e1a620992365ed75613b508e32f94b40a4..6002cc73c8b41c2f20beaf0158af813807e58c90 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base @@ -102,6 +104,21 @@ class GroupByWindowTest(test.TestCase): self.assertAllEqual([0, 0, 0], sess.run(get_next)) self.assertAllEqual([1], sess.run(get_next)) + def testEmpty(self): + iterator = ( + dataset_ops.Dataset.range(4).apply( + grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Window size must be greater than zero, but got 0."): + print(sess.run(get_next)) + def testReduceFuncError(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -379,5 +396,118 @@ class BucketTest(test.TestCase): self.assertEqual(batches, 15) +class BucketBySequenceLength(test.TestCase): + + def testBucket(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25, 35] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + def testPadToBoundary(self): + + boundaries = [10, 20, 30] + batch_sizes = [10, 8, 4, 2] + lengths = [8, 13, 25] + + def element_gen(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes[:-1], lengths): + for _ in range(batch_size): + elements.append([1] * length) + random.shuffle(elements) + for el in elements: + yield (el,) + for _ in range(batch_sizes[-1]): + el = [1] * (boundaries[-1] + 5) + yield (el,) + + element_len = lambda el: array_ops.shape(el)[0] + dataset = dataset_ops.Dataset.from_generator( + element_gen, (dtypes.int64,), ([None],)).apply( + grouping.bucket_by_sequence_length( + element_len, boundaries, batch_sizes, + pad_to_bucket_boundary=True)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + batches = [] + for _ in range(3): + batches.append(sess.run(batch)) + with self.assertRaisesOpError("bucket_boundaries"): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + batch_size = batch.shape[0] + length = batch.shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + batch_sizes = batch_sizes[:-1] + self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(boundaries), sorted(lengths_val)) + + def testTupleElements(self): + + def elements_gen(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (x, y) + + def element_length_fn(x, y): + del y + return array_ops.shape(x)[0] + + dataset = dataset_ops.Dataset.from_generator( + generator=elements_gen, + output_shapes=(tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([])), + output_types=(dtypes.int32, dtypes.int32)) + dataset = dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8])) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index dbc35097ddda9f0375060d43aeb43efa8107f929..78ecce8f7daaf84002ae78d8d77820755b967d89 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -163,7 +163,7 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs, sparse_tensors=False, verify_exhausted=True): - """Verifies that restoring into an already initilized iterator works. + """Verifies that restoring into an already initialized iterator works. Args: ds_fn: See `run_core_tests`. diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index 32ea44f7c7ba329dc253bb9fbbcac0a1ed16aec7..87b7c6ddb7afcbaaf8fe97cd8be87e6f5af8cd4d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -22,6 +22,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -33,17 +34,25 @@ class GetSingleElementTest(test.TestCase): take_value = array_ops.placeholder_with_default( constant_op.constant(1, dtype=dtypes.int64), shape=[]) + def make_sparse(x): + x_1d = array_ops.reshape(x, [1]) + x_2d = array_ops.reshape(x, [1, 1]) + return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d) + dataset = (dataset_ops.Dataset.range(100) .skip(skip_value) - .map(lambda x: x * x) + .map(lambda x: (x * x, make_sparse(x))) .take(take_value)) element = get_single_element.get_single_element(dataset) with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + for x in [0, 5, 10]: + dense_val, sparse_val = sess.run(element, feed_dict={skip_value: x}) + self.assertEqual(x * x, dense_val) + self.assertAllEqual([[x]], sparse_val.indices) + self.assertAllEqual([x], sparse_val.values) + self.assertAllEqual([x], sparse_val.dense_shape) with self.assertRaisesRegexp(errors.InvalidArgumentError, "Dataset was empty."): diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 256ad8d94dc1a7c2b26df3f1ebf8e8e321882c15..43aa4b1bd02791ff304a990c0bbe8e45534c0c77 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -30,6 +30,7 @@ from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -94,6 +95,76 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -338,7 +409,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): """Tests where all the workers race in producing elements. - Note: this is in contrast with the prevous test which carefully sequences + Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. Args: @@ -424,7 +495,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): """Tests where all the workers race in producing elements. - Note: this is in contrast with the prevous test which carefully sequences + Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. @@ -836,5 +907,114 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 10000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) + + # Also check that `weights` as a dataset samples correctly. + probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index dc3e38db59301bf1819999f479171af35930e9d2..b08132cd72254326d965907a1fdafb8a820926a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import threading from tensorflow.contrib.data.python.ops import prefetching_ops @@ -26,37 +25,43 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class StagingAreaOpsTest(test.TestCase): +class PrefetchingKernelsOpsTest(test.TestCase): def setUp(self): self._event = threading.Event() - def _prefetch_fn_helper(self, buffer_name, device0, device1): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 + def _create_ds_and_iterator(self, device0, initializable=False): def gen(): - for i in itertools.count(start=1, step=1): - yield [i + 0.0] + for i in range(1, 10): + yield [float(i)] if i == 6: self._event.set() with ops.device(device0): - dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() + ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) + if initializable: + ds_iterator = ds.make_initializable_iterator() + else: + ds_iterator = ds.make_one_shot_iterator() + return (ds, ds_iterator) + + def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1): + ds_iterator_handle = ds_iterator.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) + h, ds.output_types, ds.output_shapes) return remote_iterator.get_next() target = constant_op.constant(device0) @@ -64,15 +69,28 @@ class StagingAreaOpsTest(test.TestCase): buffer_resource_handle = prefetching_ops.function_buffering_resource( f=_remote_fn, target_device=target, - string_arg=iterator_3_handle, + string_arg=ds_iterator_handle, buffer_size=3, - thread_pool_size=2, shared_name=buffer_name) with ops.device(device1): prefetch_op = prefetching_ops.function_buffering_resource_get_next( function_buffer_resource=buffer_resource_handle, output_types=[dtypes.float32]) + reset_op = prefetching_ops.function_buffering_resource_reset( + function_buffer_resource=buffer_resource_handle) + destroy_op = resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True) + + return (prefetch_op, reset_op, destroy_op) + + def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False) + prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name, + device0, device1) with self.test_session(config=worker_config) as sess: elem = sess.run(prefetch_op) @@ -86,26 +104,277 @@ class StagingAreaOpsTest(test.TestCase): self._event.wait() elem = sess.run(prefetch_op) self.assertEqual(elem, [5.0]) - sess.run( - resource_variable_ops.destroy_resource_op( - buffer_resource_handle, ignore_lookup_error=True)) + sess.run(destroy_op) def testSameDeviceCPU(self): - self._prefetch_fn_helper("same_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:0") + self._prefetch_fn_helper_one_shot("same_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:0") def testDifferentDeviceCPU(self): - self._prefetch_fn_helper("diff_device_cpu", - "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/cpu:1") + self._prefetch_fn_helper_one_shot("diff_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:1") def testDifferentDeviceCPUGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") - self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0", - "/job:localhost/replica:0/task:0/gpu:0") + self._prefetch_fn_helper_one_shot("cpu_gpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/gpu:0") + + def testReinitialization(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + # Lets reset the function buffering resource and reinitialize the + # iterator. Should be able to go through this again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + sess.run(destroy_op) + + def testReinitializationOutOfRange(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + device0 = "/job:localhost/replica:0/task:0/cpu:0" + device1 = "/job:localhost/replica:0/task:0/cpu:1" + ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True) + prefetch_op, reset_op, destroy_op = self._create_ops( + ds, ds_iterator, "reinit", device0, device1) + + with self.test_session(config=worker_config) as sess: + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + # Now reset everything and try it out again. + self._event.clear() + sess.run(reset_op) + sess.run(ds_iterator.initializer) + for i in range(1, 10): + elem = sess.run(prefetch_op) + self.assertEqual(elem, [float(i)]) + # Try fetching after its over twice to test out end of sequence. + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(prefetch_op) + + sess.run(destroy_op) + + +class PrefetchToDeviceTest(test.TestCase): + + def testPrefetchToDevice(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchDictToDevice(self): + host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x}) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element["a"].dtype) + self.assertEqual([], next_element["a"].shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + self.assertEqual({"a": i}, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchSparseTensorsToDevice(self): + def make_tensor(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2]) + host_dataset = dataset_ops.Dataset.range(10).map(make_tensor) + + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_one_shot_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + for i in range(10): + actual = sess.run(next_element) + self.assertAllEqual([i], actual.values) + self.assertAllEqual([[0, 0]], actual.indices) + self.assertAllEqual([2, 2], actual.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpu(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceWithReInit(self): + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/cpu:1")) + + # NOTE(mrry): This device block creates the "host" dataset and iterator on + # /cpu:0, and ensures that the prefetching is across devices. In typical use + # this would not be necessary, because the GPU device would not support any + # of the dataset-related ops. + with ops.device("/cpu:0"): + iterator = device_dataset.make_initializable_iterator() + + self.assertEqual(host_dataset.output_types, device_dataset.output_types) + self.assertEqual(host_dataset.output_types, iterator.output_types) + self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes) + self.assertEqual(host_dataset.output_shapes, iterator.output_shapes) + self.assertEqual(host_dataset.output_classes, device_dataset.output_classes) + self.assertEqual(host_dataset.output_classes, iterator.output_classes) + + next_element = iterator.get_next() + self.assertEqual(dtypes.int64, next_element.dtype) + self.assertEqual([], next_element.shape) + + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + with self.test_session(config=worker_config) as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToDeviceGpuWithReInit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(5): + self.assertEqual(i, sess.run(next_element)) + sess.run(iterator.initializer) + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 6efe97444a375febc550ff3a3ea04bcd9330a3a5..1075302bae96ca2e0111efbacdf5e919ea76897d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,8 @@ import gzip import os import zlib +import numpy as np + from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 @@ -262,12 +264,20 @@ class ReadBatchFeaturesTest(test.TestCase): self._num_records = 7 self.test_filenames = self._createFiles() - def _read_batch_features(self, filenames, num_epochs, batch_size): + def _read_batch_features(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): self.filenames = filenames self.num_epochs = num_epochs self.batch_size = batch_size - return readers.read_batch_features( + return readers.make_batched_features_dataset( file_pattern=self.filenames, batch_size=self.batch_size, features={ @@ -276,22 +286,29 @@ class ReadBatchFeaturesTest(test.TestCase): "keywords": parsing_ops.VarLenFeature(dtypes.string) }, reader=core_readers.TFRecordDataset, - randomize_input=False, - num_epochs=self.num_epochs) + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch).make_one_shot_iterator( + ).get_next() def _record(self, f, r): - example = example_pb2.Example(features=feature_pb2.Features( - feature={ - "file": - feature_pb2.Feature(int64_list=feature_pb2.Int64List( - value=[f])), - "record": - feature_pb2.Feature(int64_list=feature_pb2.Int64List( - value=[r])), - "keywords": - feature_pb2.Feature(bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) return example.SerializeToString() def _get_keywords(self, f, r): @@ -312,24 +329,35 @@ class ReadBatchFeaturesTest(test.TestCase): writer.close() return filenames - def _next_actual_batch(self, sess): - file_op = self.outputs["file"] - keywords_indices_op = self.outputs["keywords"].indices - keywords_values_op = self.outputs["keywords"].values - keywords_dense_shape_op = self.outputs["keywords"].dense_shape - record_op = self.outputs["record"] + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] return sess.run([ file_op, keywords_indices_op, keywords_values_op, keywords_dense_shape_op, record_op ]) - def _next_expected_batch(self, file_indices, batch_size, num_epochs): + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): def _next_record(file_indices): for j in file_indices: for i in range(self._num_records): yield j, i + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + file_batch = [] keywords_batch_indices = [] keywords_batch_values = [] @@ -337,15 +365,19 @@ class ReadBatchFeaturesTest(test.TestCase): record_batch = [] batch_index = 0 for _ in range(num_epochs): - for record in _next_record(file_indices): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: f = record[0] r = record[1] file_batch.append(f) record_batch.append(r) keywords = self._get_keywords(f, r) keywords_batch_values.extend(keywords) - keywords_batch_indices.extend([[batch_index, i] - for i in range(len(keywords))]) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) batch_index += 1 keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) if len(file_batch) == batch_size: @@ -365,14 +397,41 @@ class ReadBatchFeaturesTest(test.TestCase): [len(file_batch), keywords_batch_max_len], record_batch ] - def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1): + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): if file_index is not None: file_indices = [file_index] else: file_indices = range(self._num_files) - for expected_batch in self._next_expected_batch(file_indices, batch_size, - num_epochs): + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): actual_batch = self._next_actual_batch(sess) for i in range(len(expected_batch)): self.assertAllEqual(expected_batch[i], actual_batch[i]) @@ -418,9 +477,10 @@ class ReadBatchFeaturesTest(test.TestCase): "file": parsing_ops.FixedLenFeature([], dtypes.int64), "record": parsing_ops.FixedLenFeature([], dtypes.int64), } - dataset = (core_readers.TFRecordDataset(self.test_filenames) - .map(lambda x: parsing_ops.parse_single_example(x, features)) - .repeat(10).batch(2)) + dataset = ( + core_readers.TFRecordDataset(self.test_filenames) + .map(lambda x: parsing_ops.parse_single_example(x, features)) + .repeat(10).batch(2)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer next_element = iterator.get_next() @@ -435,6 +495,596 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testReadWithFusedShuffleRepeatDataset(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + for batch_size in [1, 2]: + # Test that shuffling with same seed produces the same result. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs2 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + # Test that shuffling with different seeds produces a different order. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs2 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=15) + all_equal = True + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + def testParallelReadersAndParsers(self): + num_epochs = 5 + for batch_size in [1, 2]: + for reader_num_threads in [2, 4]: + for parser_num_threads in [2, 4]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + self.outputs = self._read_batch_features( + filenames=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads) + self._verify_records( + sess, + batch_size, + num_epochs=num_epochs, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess) + + def testDropFinalBatch(self): + for batch_size in [1, 2]: + for num_epochs in [1, 10]: + with ops.Graph().as_default(): + # Basic test: read from file 0. + self.outputs = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + drop_final_batch=True) + for _, tensor in self.outputs.items(): + if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. + self.assertEqual(tensor.shape[0], batch_size) + + +class MakeCsvDatasetTest(test.TestCase): + + COLUMN_TYPES = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] + COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))] + DEFAULT_VALS = [[], [], [], [], ["NULL"]] + DEFAULTS = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.float64), + constant_op.constant(["NULL"], dtype=dtypes.string) + ] + LABEL = COLUMNS[0] + + def setUp(self): + super(MakeCsvDatasetTest, self).setUp() + self._num_files = 2 + self._num_records = 11 + self._test_filenames = self._create_files() + + def _csv_values(self, fileno, recordno): + return [ + fileno, + recordno, + fileno * recordno * 0.5, + fileno * recordno + 0.5, + "record %d" % recordno if recordno % 2 == 1 else "", + ] + + def _write_file(self, filename, rows): + for i in range(len(rows)): + if isinstance(rows[i], list): + rows[i] = ",".join(str(v) if v is not None else "" for v in rows[i]) + fn = os.path.join(self.get_temp_dir(), filename) + f = open(fn, "w") + f.write("\n".join(rows)) + f.close() + return fn + + def _create_file(self, fileno, header=True, comment=True): + rows = [] + if header: + rows.append(self.COLUMNS) + for recno in range(self._num_records): + rows.append(self._csv_values(fileno, recno)) + if comment: + rows.append("# Some comment goes here. Ignore me.") + return self._write_file("csv_file%d.csv" % fileno, rows) + + def _create_files(self): + filenames = [] + for i in range(self._num_files): + filenames.append(self._create_file(i)) + return filenames + + def _make_csv_dataset( + self, + filenames, + defaults, + column_names=COLUMNS, + label_name=LABEL, + select_cols=None, + batch_size=1, + num_epochs=1, + shuffle=False, + shuffle_seed=None, + header=True, + comment="#", + na_value="", + default_float_type=dtypes.float32, + ): + return readers.make_csv_dataset( + filenames, + batch_size=batch_size, + column_names=column_names, + column_defaults=defaults, + label_name=label_name, + num_epochs=num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + header=header, + comment=comment, + na_value=na_value, + default_float_type=default_float_type, + select_columns=select_cols, + ) + + def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults): + features = {col: list() for col in self.COLUMNS} + for _ in range(num_epochs): + for i in file_indices: + for j in range(self._num_records): + values = self._csv_values(i, j) + for n, v in enumerate(values): + if v == "": # pylint: disable=g-explicit-bool-comparison + values[n] = defaults[n][0] + values[-1] = values[-1].encode("utf-8") + + # Regroup lists by column instead of row + for n, col in enumerate(self.COLUMNS): + features[col].append(values[n]) + if len(list(features.values())[0]) == batch_size: + yield features + features = {col: list() for col in self.COLUMNS} + + def _run_actual_batch(self, outputs, sess): + features, labels = sess.run(outputs) + batch = [features[k] for k in self.COLUMNS if k != self.LABEL] + batch.append(labels) + return batch + + def _verify_records( + self, + sess, + dataset, + file_indices, + defaults=tuple(DEFAULT_VALS), + label_name=LABEL, + batch_size=1, + num_epochs=1, + ): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + for expected_features in self._next_actual_batch(file_indices, batch_size, + num_epochs, defaults): + actual_features = sess.run(get_next) + + if label_name is not None: + expected_labels = expected_features.pop(label_name) + # Compare labels + self.assertAllEqual(expected_labels, actual_features[1]) + actual_features = actual_features[0] # Extract features dict from tuple + + for k in expected_features.keys(): + # Compare features + self.assertAllEqual(expected_features[k], actual_features[k]) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testMakeCSVDataset(self): + defaults = self.DEFAULTS + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Basic test: read from file 0. + dataset = self._make_csv_dataset(self._test_filenames[0], defaults) + self._verify_records(sess, dataset, [0]) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Basic test: read from file 1. + dataset = self._make_csv_dataset(self._test_filenames[1], defaults) + self._verify_records(sess, dataset, [1]) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. + dataset = self._make_csv_dataset(self._test_filenames, defaults) + self._verify_records(sess, dataset, range(self._num_files)) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, defaults, batch_size=2, num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def testMakeCSVDataset_withBadColumns(self): + """Tests that exception is raised when input is malformed. + """ + dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1] + defaults = self.DEFAULTS + + # Duplicate column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, column_names=dupe_columns) + + # Label key not one of column names + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, defaults, label_name="not_a_real_label") + + def testMakeCSVDataset_withNoLabel(self): + """Tests that CSV datasets can be created when no label is specified. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Make sure this works with no label key supplied. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=2, + num_epochs=10, + label_name=None) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + label_name=None) + + def testMakeCSVDataset_withNoComments(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), comment=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + comment=None, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) + + def testMakeCSVDataset_withNoHeader(self): + """Tests that datasets can be created from CSV files with no header line. + """ + defaults = self.DEFAULTS + file_without_header = self._create_file( + len(self._test_filenames), header=False) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + file_without_header, + defaults, + batch_size=2, + num_epochs=10, + header=False, + ) + self._verify_records( + sess, + dataset, + [len(self._test_filenames)], + batch_size=2, + num_epochs=10, + ) + + def testMakeCSVDataset_withTypes(self): + """Tests that defaults can be a dtype instead of a Tensor for required vals. + """ + defaults = [d for d in self.COLUMN_TYPES[:-1]] + defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string)) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset(self._test_filenames, defaults) + self._verify_records(sess, dataset, range(self._num_files)) + + def testMakeCSVDataset_withNoColNames(self): + """Tests that datasets can be created when column names are not specified. + + In that case, we should infer the column names from the header lines. + """ + defaults = self.DEFAULTS + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Read from both files. Exercise the `batch` and `num_epochs` parameters + # of make_csv_dataset and make sure they work. + dataset = self._make_csv_dataset( + self._test_filenames, + defaults, + column_names=None, + batch_size=2, + num_epochs=10) + self._verify_records( + sess, dataset, range(self._num_files), batch_size=2, num_epochs=10) + + def testMakeCSVDataset_withTypeInferenceMismatch(self): + # Test that error is thrown when num fields doesn't match columns + with self.assertRaises(ValueError): + self._make_csv_dataset( + self._test_filenames, + column_names=self.COLUMNS + ["extra_name"], + defaults=None, + batch_size=2, + num_epochs=10) + + def testMakeCSVDataset_withTypeInference(self): + """Tests that datasets can be created when no defaults are specified. + + In that case, we should infer the types from the first N records. + """ + # Test that it works with standard test files (with comments, header, etc) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + self._test_filenames, defaults=None, batch_size=2, num_epochs=10) + self._verify_records( + sess, + dataset, + range(self._num_files), + batch_size=2, + num_epochs=10, + defaults=[[], [], [], [], [""]]) + + # Test on a deliberately tricky file + fn = os.path.join(self.get_temp_dir(), "file.csv") + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, + dtypes.string, dtypes.string + ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[None, None, None, "NAN", "", + "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""], + ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']] + expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""], + [123, 2, 2**64, 123.4, "", "cd,efg"]] + for row in expected: + row[-1] = row[-1].encode("utf-8") # py3 expects byte strings + row[-2] = row[-2].encode("utf-8") # py3 expects byte strings + self._write_file("file.csv", [col_names] + rows) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + na_value="NAN", + default_float_type=dtypes.float32, + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + assert features["col%d" % i].dtype == expected_dtypes[i] + for i in range(len(rows)): + assert sess.run(features) == dict(zip(col_names, expected[i])) + + # With float64 as default type for floats + expected_dtypes = [ + dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.string, dtypes.string + ] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + na_value="NAN", + default_float_type=dtypes.float64, + ) + features = dataset.make_one_shot_iterator().get_next() + # Check that types match + for i in range(len(expected_dtypes)): + self.assertAllEqual(features["col%d" % i].dtype, expected_dtypes[i]) + for i in range(len(rows)): + self.assertAllEqual( + sess.run(features), dict(zip(col_names, expected[i]))) + + def testMakeCSVDataset_withSelectColsError(self): + data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + col_names = ["col%d" % i for i in range(5)] + fn = self._write_file("file.csv", [col_names] + data) + with self.assertRaises(ValueError): + # Mismatch in number of defaults and number of columns selected, + # should raise an error + self._make_csv_dataset( + fn, + defaults=[[0]] * 5, + column_names=col_names, + label_name=None, + select_cols=[1, 3]) + with self.assertRaises(ValueError): + # Invalid column name should raise an error + self._make_csv_dataset( + fn, + defaults=[[0]], + column_names=col_names, + label_name=None, + select_cols=["invalid_col_name"]) + + def testMakeCSVDataset_withSelectCols(self): + data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] + col_names = ["col%d" % i for i in range(5)] + fn = self._write_file("file.csv", [col_names] + data) + # If select_cols is specified, should only yield a subset of columns + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=[[0], [0]], + column_names=col_names, + label_name=None, + select_cols=[1, 3]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + # Can still do default inference with select_cols + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=col_names, + label_name=None, + select_cols=[1, 3]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + # Can still do column name inference + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + select_cols=[1, 3]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + # Can specify column names instead of indices + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = self._make_csv_dataset( + fn, + defaults=None, + column_names=None, + label_name=None, + select_cols=[col_names[1], col_names[3]]) + expected = [[1, 3], [6, 8]] + features = dataset.make_one_shot_iterator().get_next() + for i in range(len(data)): + self.assertAllEqual( + sess.run(features), + dict(zip([col_names[1], col_names[3]], expected[i]))) + + def testMakeCSVDataset_withShuffle(self): + total_records = self._num_files * self._num_records + defaults = self.DEFAULTS + for batch_size in [1, 2]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Test that shuffling with the same seed produces the same result + dataset1 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + dataset2 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + # Test that shuffling with a different seed produces different results + dataset1 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + dataset2 = self._make_csv_dataset( + self._test_filenames, + defaults, + batch_size=batch_size, + shuffle=True, + shuffle_seed=6) + outputs1 = dataset1.make_one_shot_iterator().get_next() + outputs2 = dataset2.make_one_shot_iterator().get_next() + all_equal = False + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 3c7b46629edb13459766b5ef3f392e8d00ad4db8..5f47dcb33999119a690bd633f0c97a12a1ae1c84 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -21,7 +21,10 @@ import numpy as np from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -45,12 +48,10 @@ class ResampleTest(test.TestCase): target_dist=target_dist, initial_dist=initial_dist, class_func=lambda c, _: c, - seed=27)).make_initializable_iterator()) - init_op = iterator.initializer + seed=27)).make_one_shot_iterator()) get_next = iterator.get_next() with self.test_session() as sess: - sess.run(init_op) returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -70,6 +71,43 @@ class ResampleTest(test.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) + def testRandomClasses(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Apply a random mapping that preserves the data distribution. + def _remap_fn(_): + return math_ops.cast(random_ops.random_uniform([1]) * num_classes, + dtypes.int32)[0] + dataset = dataset.map(_remap_fn) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + with self.assertRaises(errors.OutOfRangeError): + while True: + returned.append(sess.run(get_next)) + + classes, _ = zip(*returned) + bincount = np.bincount( + np.array(classes), + minlength=num_classes).astype(np.float32) / len(classes) + + self.assertAllClose(target_dist, bincount, atol=1e-2) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 36ddf3004237ed042f21d691d83eafbaa20621e6..d0cb203a3afd2775756c8542a1e86faedc5cee53 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -47,6 +47,11 @@ class SequenceDatasetSerializationTest( # Skip nothing self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + def testInvalidSkip(self): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -69,6 +74,11 @@ class SequenceDatasetSerializationTest( # Take nothing self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + def testInvalidTake(self): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -100,6 +110,12 @@ class SequenceDatasetSerializationTest( # Test repeat empty dataset self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + def testInvalidRepeat(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0), + None, 0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..33c48e20bea53b88d69a59e715af38b22dd2cbd4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -0,0 +1,242 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import sliding +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class SlideDatasetTest(test.TestCase): + + def testSlideDataset(self): + """Test an dataset that maps a TF function across its input elements.""" + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + window_size = array_ops.placeholder(dtypes.int64, shape=[]) + stride = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> + # RepeatDataset(count) -> _SlideDataset(window_size, stride). + iterator = (dataset_ops.Dataset.from_tensor_slices(components) + .map(_map_fn) + .repeat(count) + .apply(sliding.sliding_window_batch(window_size, stride)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + # Slide over a finite input, where the window_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7}) + # Same formula with convolution layer. + num_batches = (20 * 7 - 14) // 7 + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*7 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over a finite input, where the window_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9}) + + num_batches = (20 * 7 - 17) // 9 + 1 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(17): + self.assertAllEqual(component[(i*9 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over a finite input, which is less than window_size, + # should fail straight away. + sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Slide over an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, window_size: 8, stride: 4}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty window_size should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 0, stride: 0}) + + # Invalid stride should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0}) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3}) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5}) + + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSlideSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSlideSparseWithDifferentDenseShapes(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=array_ops.expand_dims( + math_ops.range(i, dtype=dtypes.int64), 1), + values=array_ops.fill([math_ops.to_int32(i)], i), + dense_shape=[i]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).apply( + sliding.sliding_window_batch(5, 3)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + num_batches = (10 - 5) // 3 + 1 + for i in range(num_batches): + actual = sess.run(get_next) + expected_indices = [] + expected_values = [] + for j in range(5): + for k in range(i * 3 + j): + expected_indices.append([j, k]) + expected_values.append(i * 3 + j) + expected = sparse_tensor.SparseTensorValue( + indices=expected_indices, + values=expected_values, + dense_shape=[5, i * 3 + 5 - 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedSlideSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .apply(sliding.sliding_window_batch(4, 2)) + .apply(sliding.sliding_window_batch(3, 1)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + # Slide: 1st batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], + [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + # Slide: 2nd batch. + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], + [2, 0, 0], [2, 1, 0], [2, 2, 0], [2, 3, 0]], + values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9], + dense_shape=[3, 4, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSlideShapeError(self): + + def generator(): + yield [1.0, 2.0, 3.0] + yield [4.0, 5.0, 6.0] + yield [7.0, 8.0, 9.0, 10.0] + + iterator = (dataset_ops.Dataset.from_generator(generator, dtypes.float32, + output_shapes=[None]) + .apply(sliding.sliding_window_batch(3, 1)) + .make_initializable_iterator()) + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot batch tensors with different shapes in component 0. " + r"First element had shape \[3\] and element 2 had shape \[4\]."): + sess.run(next_element) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 07bdf920446e953c2a1abaf495d2e9e1256106fd..c3a7f291c59a72dc6057f7e1c51d5ac78334176b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -218,6 +218,14 @@ class StatsDatasetSerializationTest( lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( stats_ops.bytes_produced_stats("bytes_produced")) + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + def testBytesStatsDatasetSaveableCore(self): num_outputs = 100 self.run_core_tests( @@ -235,6 +243,14 @@ class StatsDatasetSerializationTest( return dataset_ops.Dataset.range(num_elements).apply( stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, 'Shape must be rank 0 but is rank 1'): + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + def testLatencyStatsDatasetSaveableCore(self): num_outputs = 100 diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9167cb3379bba5cb1ba76a96549395c45dca9e35 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import numpy as np + +from tensorflow.contrib.data.python.ops import threadpool +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class OverrideThreadpoolDatasetTest(test.TestCase): + + def testNumThreads(self): + + def get_thread_id(_): + # Python creates a dummy thread object to represent the current + # thread when called from an "alien" thread (such as a + # `PrivateThreadPool` thread in this case). It does not include + # the TensorFlow-given display name, but it has a unique + # identifier that maps one-to-one with the underlying OS thread. + return np.array(threading.current_thread().ident).astype(np.int64) + + for num_threads in [1, 2, 4, 8, 16]: + + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) + + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, display_name="private_thread_pool_%d" % num_threads)) + + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index b488357f226d0922bba3799cc1f4b5c75e2e8328..e00f2304cc415e45b76f4f53c51bca1cfe4ac114 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -12,18 +12,26 @@ load( load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") py_library( - name = "dataset_ops", - srcs = [ - "counter.py", - "get_single_element.py", + name = "counter", + srcs = ["counter.py"], + srcs_version = "PY2AND3", + deps = [ + ":scan_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", ], +) + +py_library( + name = "get_single_element", + srcs = ["get_single_element.py"], srcs_version = "PY2AND3", deps = [ - ":transformation_ops", "//tensorflow/python:dataset_ops_gen", - "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) @@ -66,18 +74,25 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":dataset_ops", + ":batching", + ":interleave_ops", + ":shuffle_ops", + "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", + "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", ], ) @@ -88,46 +103,180 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":random_ops", - ":transformation_ops", "//tensorflow/python/data/ops:dataset_ops", ], ) py_library( - name = "transformation_ops", - srcs = [ - "batching.py", - "enumerate_ops.py", - "error_ops.py", - "grouping.py", - "interleave_ops.py", - "resampling.py", - "scan_ops.py", - "stats_ops.py", - "unique.py", + name = "batching", + srcs = ["batching.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "enumerate_ops", + srcs = ["enumerate_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", ], +) + +py_library( + name = "error_ops", + srcs = ["error_ops.py"], srcs_version = "PY2AND3", deps = [ ":contrib_op_loader", ":gen_dataset_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "grouping", + srcs = ["grouping.py"], + srcs_version = "PY2AND3", + deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", - "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "interleave_ops", + srcs = ["interleave_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + ":random_ops", + "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "resampling", + srcs = ["resampling.py"], + srcs_version = "PY2AND3", + deps = [ + ":batching", + ":scan_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:logging_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_library( + name = "scan_ops", + srcs = ["scan_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "sliding", + srcs = ["sliding.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "stats_ops", + srcs = ["stats_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +py_library( + name = "threadpool", + srcs = ["threadpool.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + "//tensorflow/python/eager:context", + ], +) + +py_library( + name = "unique", + srcs = [ + "unique.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", - "//third_party/py/numpy", ], ) @@ -167,17 +316,35 @@ py_library( srcs = ["prefetching_ops.py"], deps = [ ":contrib_op_loader", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +py_library( + name = "dataset_ops", + deps = [ + ":batching", + ":counter", + ":enumerate_ops", + ":error_ops", + ":get_single_element", + ":grouping", + ":interleave_ops", + ":prefetching_ops", + ":readers", + ":resampling", + ":scan_ops", + ":shuffle_ops", + ":sliding", + ":stats_ops", + ":threadpool", + ":unique", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 6eb512dec67cb7b9c8c4518d03aee0b436205f9a..28db949da9e371bfa27fa847faee88f0366699ba 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse @@ -345,16 +346,62 @@ class _RestructuredDataset(dataset_ops.Dataset): return self._output_shapes +def assert_element_shape(expected_shapes): + """Assert the shape of this `Dataset`. + + ```python + shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)] + result = dataset.apply(tf.contrib.data.assert_element_shape(shapes)) + print(result.output_shapes) # ==> "((16, 256), )" + ``` + + If dataset shapes and expected_shape, are fully defined, assert they match. + Otherwise, add assert op that will validate the shapes when tensors are + evaluated, and set shapes on tensors, respectively. + + Args: + expected_shapes: A nested structure of `tf.TensorShape` objects. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ + + def _check_shape(*elements): + flatten_tensors = nest.flatten(elements) + flatten_shapes = nest.flatten(expected_shapes) + checked_tensors = [ + with_shape(shape, tensor) + for shape, tensor in zip(flatten_shapes, flatten_tensors) + ] + return nest.pack_sequence_as(elements, checked_tensors) + + def _apply_fn(dataset): + return _RestructuredDataset( + dataset.map(_check_shape), + dataset.output_types, + output_shapes=expected_shapes, + output_classes=dataset.output_classes) + + return _apply_fn + + class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size = ops.convert_to_tensor( + self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches = ops.convert_to_tensor( + self._num_parallel_batches_t = ops.convert_to_tensor( num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._drop_remainder_t = ops.convert_to_tensor( + drop_remainder, dtype=dtypes.bool, name="drop_remainder") + + self._batch_size = batch_size + self._drop_remainder = drop_remainder def _as_variant_tensor(self): # pylint: disable=protected-access @@ -363,8 +410,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): input_resource, self._map_func.captured_inputs, f=self._map_func, - batch_size=self._batch_size, - num_parallel_batches=self._num_parallel_batches, + batch_size=self._batch_size_t, + num_parallel_batches=self._num_parallel_batches_t, + drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( @@ -373,9 +421,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): @property def output_shapes(self): + dim = self._batch_size if self._drop_remainder else None return nest.pack_sequence_as(self._output_shapes, [ - tensor_shape.vector(tensor_util.constant_value( - self._batch_size)).concatenate(s) + tensor_shape.vector(dim).concatenate(s) for s in nest.flatten(self._output_shapes) ]) @@ -384,7 +432,10 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): return self._output_types -def map_and_batch(map_func, batch_size, num_parallel_batches=1): +def map_and_batch(map_func, + batch_size, + num_parallel_batches=1, + drop_remainder=False): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -404,6 +455,9 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values can increase contention if CPU is scarce. + drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the + last batch should be dropped in case its size is smaller than desired; + the default behavior is not to drop the smaller batch. Returns: A `Dataset` transformation function, which can be passed to @@ -412,6 +466,6 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches) + num_parallel_batches, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py index 63226fe78163c59025623a362d17c400fbe57c67..6ef65f9624601286691505a795a86dd6226eead1 100644 --- a/tensorflow/contrib/data/python/ops/counter.py +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import ops def Counter(start=0, step=1, dtype=dtypes.int64): - """Creates a `Dataset` of a `step`-separated count startin from `start`. + """Creates a `Dataset` that counts from `start` in steps of size `step`. For example: @@ -38,12 +38,13 @@ def Counter(start=0, step=1, dtype=dtypes.int64): ``` Args: - start: starting value for count. - step: step size. - dtype: counter data type. + start: (Optional.) The starting value for the counter. Defaults to 0. + step: (Optional.) The step size for the counter. Defaults to 1. + dtype: (Optional.) The data type for counter elements. Defaults to + `tf.int64`. Returns: - A `Dataset` of scalar elements. + A `Dataset` of scalar `dtype` elements. """ with ops.name_scope("counter"): start = ops.convert_to_tensor(start, dtype=dtype, name="start") diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py deleted file mode 100644 index ff15c4451ad987bcd77dbdd022a1c070056c47e1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ /dev/null @@ -1,691 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gen_io_ops -from tensorflow.python.util import deprecation - - -class Dataset(dataset_ops.Dataset): - """Represents a potentially large set of elements. - - A `Dataset` can be used to represent an input pipeline as a - collection of elements (nested structures of tensors) and a "logical - plan" of transformations that act on those elements. - """ - - def __init__(self, dataset): - super(Dataset, self).__init__() - self._dataset = dataset - - @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.") - def make_dataset_resource(self): - return self._as_variant_tensor() - - def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access - - @property - def output_classes(self): - return self._dataset.output_classes - - @property - def output_shapes(self): - return self._dataset.output_shapes - - @property - def output_types(self): - return self._dataset.output_types - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.") - def from_tensors(tensors): - """Creates a `Dataset` with a single element, comprising the given tensors. - - Args: - tensors: A nested structure of tensors. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TensorDataset(tensors)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") - def from_tensor_slices(tensors): - """Creates a `Dataset` whose elements are slices of the given tensors. - - Args: - tensors: A nested structure of tensors, each having the same size in the - 0th dimension. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TensorSliceDataset(tensors)) - - @staticmethod - @deprecation.deprecated(None, - "Use `tf.data.Dataset.from_sparse_tensor_slices()`.") - def from_sparse_tensor_slices(sparse_tensor): - """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. - - Args: - sparse_tensor: A `tf.SparseTensor`. - - Returns: - A `Dataset` of rank-(N-1) sparse tensors. - """ - return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") - def from_generator(generator, output_types, output_shapes=None): - """Creates a `Dataset` whose elements are generated by `generator`. - - The `generator` argument must be a callable object that returns - an object that support the `iter()` protocol (e.g. a generator function). - The elements generated by `generator` must be compatible with the given - `output_types` and (optional) `output_shapes` arguments. - - For example: - - ```python - import itertools - - def gen(): - for i in itertools.count(1): - yield (i, [1] * i) - - ds = Dataset.from_generator( - gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) - value = ds.make_one_shot_iterator().get_next() - - sess.run(value) # (1, array([1])) - sess.run(value) # (2, array([1, 1])) - ``` - - Args: - generator: A callable object that takes no arguments and returns an - object that supports the `iter()` protocol. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element yielded by `generator`. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` - objects corresponding to each component of an element yielded by - `generator`. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.Dataset.from_generator( - generator, output_types, output_shapes)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") - def range(*args): - """Creates a `Dataset` of a step-separated range of values. - - For example: - - ```python - Dataset.range(5) == [0, 1, 2, 3, 4] - Dataset.range(2, 5) == [2, 3, 4] - Dataset.range(1, 5, 2) == [1, 3] - Dataset.range(1, 5, -2) == [] - Dataset.range(5, 1) == [] - Dataset.range(5, 1, -2) == [5, 3] - ``` - - Args: - *args: follow same semantics as python's xrange. - len(args) == 1 -> start = 0, stop = args[0], step = 1 - len(args) == 2 -> start = args[0], stop = args[1], step = 1 - len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] - - Returns: - A `RangeDataset`. - - Raises: - ValueError: if len(args) == 0. - """ - return Dataset(dataset_ops.RangeDataset(*args)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") - def zip(datasets): - """Creates a `Dataset` by zipping together the given datasets. - - This method has similar semantics to the built-in `zip()` function - in Python, with the main difference being that the `datasets` - argument can be an arbitrary nested structure of `Dataset` objects. - For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { 4, 5, 6 } - c = { (7, 8), (9, 10), (11, 12) } - d = { 13, 14 } - - # The nested structure of the `datasets` argument determines the - # structure of elements in the resulting dataset. - Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) } - Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) } - - # The `datasets` argument may contain an arbitrary number of - # datasets. - Dataset.zip((a, b, c)) == { (1, 4, (7, 8)), - (2, 5, (9, 10)), - (3, 6, (11, 12)) } - - # The number of elements in the resulting dataset is the same as - # the size of the smallest dataset in `datasets`. - Dataset.zip((a, d)) == { (1, 13), (2, 14) } - ``` - - Args: - datasets: A nested structure of datasets. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ZipDataset(datasets)) - - def concatenate(self, dataset): - """Creates a `Dataset` by concatenating given dataset with this dataset. - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { 4, 5, 6, 7 } - - # Input dataset and dataset to be concatenated should have same - # nested structures and output types. - # c = { (8, 9), (10, 11), (12, 13) } - # d = { 14.0, 15.0, 16.0 } - # a.concatenate(c) and a.concatenate(d) would result in error. - - a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 } - ``` - - Args: - dataset: `Dataset` to be concatenated. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset)) - - def prefetch(self, buffer_size): - """Creates a `Dataset` that prefetches elements from this dataset. - - Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - maximum number elements that will be buffered when prefetching. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") - def list_files(file_pattern): - """A dataset of all files matching a pattern. - - Example: - If we had the following files on our filesystem: - - /path/to/dir/a.txt - - /path/to/dir/b.py - - /path/to/dir/c.py - If we pass "/path/to/dir/*.py" as the directory, the dataset would - produce: - - /path/to/dir/b.py - - /path/to/dir/c.py - - Args: - file_pattern: A string or scalar string `tf.Tensor`, representing - the filename pattern that will be matched. - - Returns: - A `Dataset` of strings corresponding to file names. - """ - return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) - - def repeat(self, count=None): - """Repeats this dataset `count` times. - - Args: - count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - number of times the elements of this dataset should be repeated. The - default behavior (if `count` is `None` or `-1`) is for the elements to - be repeated indefinitely. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") - def enumerate(self, start=0): - """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" - - return self.apply(enumerate_ops.enumerate_dataset(start)) - - def shuffle(self, buffer_size, seed=None): - """Randomly shuffles the elements of this dataset. - - Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - number of elements from this dataset from which the new - dataset will sample. - seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed)) - - def cache(self, filename=""): - """Caches the elements in this dataset. - - Args: - filename: A `tf.string` scalar `tf.Tensor`, representing the name of a - directory on the filesystem to use for caching tensors in this Dataset. - If a filename is not provided, the dataset will be cached in memory. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.CacheDataset(self._dataset, filename)) - - def take(self, count): - """Creates a `Dataset` with at most `count` elements from this dataset. - - Args: - count: A `tf.int64` scalar `tf.Tensor`, representing the number of - elements of this dataset that should be taken to form the new dataset. - If `count` is -1, or if `count` is greater than the size of this - dataset, the new dataset will contain all elements of this dataset. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TakeDataset(self._dataset, count)) - - def skip(self, count): - """Creates a `Dataset` that skips `count` elements from this dataset. - - Args: - count: A `tf.int64` scalar `tf.Tensor`, representing the number - of elements of this dataset that should be skipped to form the - new dataset. If `count` is greater than the size of this - dataset, the new dataset will contain no elements. If `count` - is -1, skips the entire dataset. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.SkipDataset(self._dataset, count)) - - def shard(self, num_shards, index): - """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. - - This dataset operator is very useful when running distributed training, as - it allows each worker to read a unique subset. - - When reading a single input file, you can skip elements as follows: - - ```python - d = tf.data.TFRecordDataset(FLAGS.input_file) - d = d.shard(FLAGS.num_workers, FLAGS.worker_index) - d = d.repeat(FLAGS.num_epochs) - d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) - ``` - - Important caveats: - - - Be sure to shard before you use any randomizing operator (such as - shuffle). - - Generally it is best if the shard operator is used early in the dataset - pipeline. For example, when reading from a set of TFRecord files, shard - before converting the dataset to input samples. This avoids reading every - file on every worker. The following is an example of an efficient - sharding strategy within a complete pipeline: - - ```python - d = tf.data.Dataset.list_files(FLAGS.pattern) - d = d.shard(FLAGS.num_workers, FLAGS.worker_index) - d = d.repeat(FLAGS.num_epochs) - d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.interleave(tf.data.TFRecordDataset, - cycle_length=FLAGS.num_readers, block_length=1) - d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) - ``` - - Args: - num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of - shards operating in parallel. - index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. - - Returns: - A `Dataset`. - - Raises: - ValueError: if `num_shards` or `index` are illegal values. Note: error - checking is done on a best-effort basis, and aren't guaranteed to be - caught upon dataset creation. (e.g. providing in a placeholder tensor - bypasses the early checking, and will instead result in an error during - a session.run call.) - """ - return Dataset(self._dataset.shard(num_shards, index)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.") - def ignore_errors(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" - - return self.apply(error_ops.ignore_errors()) - - def batch(self, batch_size): - """Combines consecutive elements of this dataset into batches. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size)) - - def padded_batch(self, batch_size, padded_shapes, padding_values=None): - """Combines consecutive elements of this dataset into padded batches. - - Like `Dataset.dense_to_sparse_batch()`, this method combines - multiple consecutive elements of this dataset, which might have - different shapes, into a single element. The tensors in the - resulting element have an additional outer dimension, and are - padded to the respective shape in `padded_shapes`. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - padded_shapes: A nested structure of `tf.TensorShape` or - `tf.int64` vector tensor-like objects representing the shape - to which the respective component of each input element should - be padded prior to batching. Any unknown dimensions - (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a - tensor-like object) will be padded to the maximum size of that - dimension in each batch. - padding_values: (Optional.) A nested structure of scalar-shaped - `tf.Tensor`, representing the padding values to use for the - respective components. Defaults are `0` for numeric types and - the empty string for string types. - - Returns: - A `Dataset`. - """ - return Dataset( - dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes, - padding_values)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.") - def dense_to_sparse_batch(self, batch_size, row_shape): - """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`.""" - - return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.group_by_window())`.") - def group_by_window(self, key_func, reduce_func, window_size): - """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`.""" - - return self.apply( - grouping.group_by_window(key_func, reduce_func, window_size)) - - @deprecation.deprecated_args( - None, - "Replace `num_threads=T` with `num_parallel_calls=T`. Replace " - "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", - "num_threads", "output_buffer_size") - def map(self, - map_func, - num_threads=None, - output_buffer_size=None, - num_parallel_calls=None): - """Maps `map_func` across this dataset. - - Args: - map_func: A function mapping a nested structure of tensors (having - shapes and types defined by `self.output_shapes` and - `self.output_types`) to another nested structure of tensors. - num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead. - output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`, - representing the maximum number of processed elements that will be - buffered. - num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, - representing the number elements to process in parallel. If not - specified, elements will be processed sequentially. - - Returns: - A `Dataset`. - """ - if num_threads is None and num_parallel_calls is None: - ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func)) - else: - if num_threads is None: - ret = Dataset( - dataset_ops.ParallelMapDataset(self._dataset, map_func, - num_parallel_calls)) - else: - ret = Dataset( - dataset_ops.ParallelMapDataset(self._dataset, map_func, - num_threads)) - if output_buffer_size is not None: - ret = ret.prefetch(output_buffer_size) - return ret - - def flat_map(self, map_func): - """Maps `map_func` across this dataset and flattens the result. - - Args: - map_func: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - `Dataset`. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func)) - - def interleave(self, map_func, cycle_length, block_length=1): - """Maps `map_func` across this dataset, and interleaves the results. - - For example, you can use `Dataset.interleave()` to process many input files - concurrently: - - ```python - # Preprocess 4 files concurrently, and interleave blocks of 16 records from - # each file. - filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...] - dataset = (Dataset.from_tensor_slices(filenames) - .interleave(lambda x: - TextLineDataset(x).map(parse_fn, num_parallel_calls=1), - cycle_length=4, block_length=16)) - ``` - - The `cycle_length` and `block_length` arguments control the order in which - elements are produced. `cycle_length` controls the number of input elements - that are processed concurrently. If you set `cycle_length` to 1, this - transformation will handle one input element at a time, and will produce - identical results = to @{tf.data.Dataset.flat_map}. In general, - this transformation will apply `map_func` to `cycle_length` input elements, - open iterators on the returned `Dataset` objects, and cycle through them - producing `block_length` consecutive elements from each iterator, and - consuming the next input element each time it reaches the end of an - iterator. - - For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3, 4, 5 } - - # NOTE: New lines indicate "block" boundaries. - a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), - cycle_length=2, block_length=4) == { - 1, 1, 1, 1, - 2, 2, 2, 2, - 1, 1, - 2, 2, - 3, 3, 3, 3, - 4, 4, 4, 4, - 3, 3, - 4, 4, - 5, 5, 5, 5, - 5, 5, - } - ``` - - NOTE: The order of elements yielded by this transformation is - deterministic, as long as `map_func` is a pure function. If - `map_func` contains any stateful operations, the order in which - that state is accessed is undefined. - - Args: - map_func: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - `Dataset`. - cycle_length: The number of elements from this dataset that will be - processed concurrently. - block_length: The number of consecutive elements to produce from each - input element before cycling to another input element. - - Returns: - A `Dataset`. - """ - return Dataset( - dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, - block_length)) - - @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") - def unbatch(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" - - return self.apply(batching.unbatch()) - - def filter(self, predicate): - """Filters this dataset according to `predicate`. - - Args: - predicate: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - scalar `tf.bool` tensor. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.FilterDataset(self._dataset, predicate)) - - def apply(self, transformation_func): - """Apply a transformation function to this dataset. - - `apply` enables chaining of custom `Dataset` transformations, which are - represented as functions that take one `Dataset` argument and return a - transformed `Dataset`. - - For example: - - ``` - dataset = (dataset.map(lambda x: x ** 2) - .(group_by_window(key_func, reduce_func, window_size)) - .map(lambda x: x ** 3)) - ``` - - Args: - transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. - - Returns: - The `Dataset` returned by applying `transformation_func` to this dataset. - """ - dataset = transformation_func(self) - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`transformation_func` must return a Dataset.") - return Dataset(dataset) - - -def get_single_element(dataset): - """Returns the single element in `dataset` as a nested structure of tensors. - - This function enables you to use a @{tf.data.Dataset} in a stateless - "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. - This can be useful when your preprocessing transformations are expressed - as a `Dataset`, and you want to use the transformation at serving time. - For example: - - ```python - input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) - - def preprocessing_fn(input_str): - # ... - return image, label - - dataset = (tf.data.Dataset.from_tensor_slices(input_batch) - .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) - .batch(BATCH_SIZE)) - - image_batch, label_batch = tf.contrib.data.get_single_element(dataset) - ``` - - Args: - dataset: A @{tf.data.Dataset} object containing a single element. - - Returns: - A nested structure of @{tf.Tensor} objects, corresponding to the single - element of `dataset`. - - Raises: - TypeError: if `dataset` is not a `tf.data.Dataset` object. - InvalidArgumentError (at runtime): if `dataset` does not contain exactly - one element. - """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( - dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index a817b45b71b608810a9d7536ec123ab84f7cdc3b..3a07df572748e464284f580d67e3a664e71acdfe 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.ops import gen_dataset_ops @@ -59,9 +60,14 @@ def get_single_element(dataset): """ if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( + + nested_ret = nest.pack_sequence_as( + dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) + output_types=nest.flatten(sparse.as_dense_types( + dataset.output_types, dataset.output_classes)), + output_shapes=nest.flatten(sparse.as_dense_shapes( + dataset.output_shapes, dataset.output_classes)))) + return sparse.deserialize_sparse_tensors( + nested_ret, dataset.output_types, dataset.output_shapes, + dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 67b085002aa7797d858837fea4646fb968ad5d97..0531f9cbb9da6e6df85fa46940ab1661ad742eb4 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,13 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops def group_by_window(key_func, @@ -35,7 +42,7 @@ def group_by_window(key_func, This transformation maps each consecutive element in a dataset to a key using `key_func` and groups the elements by key. It then applies `reduce_func` to at most `window_size_func(key)` elements matching the same - key. All execpt the final window for each key will contain + key. All except the final window for each key will contain `window_size_func(key)` elements; the final window may be smaller. You may provide either a constant `window_size` or a window size determined by @@ -85,6 +92,114 @@ def group_by_window(key_func, return _apply_fn +def bucket_by_sequence_length(element_length_func, + bucket_boundaries, + bucket_batch_sizes, + padded_shapes=None, + padding_values=None, + pad_to_bucket_boundary=False): + """A transformation that buckets elements in a `Dataset` by length. + + Elements of the `Dataset` are grouped together by length and then are padded + and batched. + + This is useful for sequence tasks in which the elements have variable length. + Grouping together elements that have similar lengths reduces the total + fraction of padding in a batch which increases training step efficiency. + + Args: + element_length_func: function from element in `Dataset` to `tf.int32`, + determines the length of the element, which will determine the bucket it + goes into. + bucket_boundaries: `list`, upper length boundaries of the buckets. + bucket_batch_sizes: `list`, batch size per bucket. Length should be + `len(bucket_boundaries) + 1`. + padded_shapes: Nested structure of `tf.TensorShape` to pass to + @{tf.data.Dataset.padded_batch}. If not provided, will use + `dataset.output_shapes`, which will result in variable length dimensions + being padded out to the maximum length in each batch. + padding_values: Values to pad with, passed to + @{tf.data.Dataset.padded_batch}. Defaults to padding with 0. + pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown + size to maximum length in batch. If `True`, will pad dimensions with + unknown size to bucket boundary, and caller must ensure that the source + `Dataset` does not contain any elements with length longer than + `max(bucket_boundaries)`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + + Raises: + ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. + """ + with ops.name_scope("bucket_by_seq_length"): + if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): + raise ValueError( + "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") + + batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) + + def element_to_bucket_id(*args): + """Return int64 id of the length bucket for this element.""" + seq_length = element_length_func(*args) + + boundaries = list(bucket_boundaries) + buckets_min = [np.iinfo(np.int32).min] + boundaries + buckets_max = boundaries + [np.iinfo(np.int32).max] + conditions_c = math_ops.logical_and( + math_ops.less_equal(buckets_min, seq_length), + math_ops.less(seq_length, buckets_max)) + bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) + + return bucket_id + + def window_size_fn(bucket_id): + # The window size is set to the batch size for this bucket + window_size = batch_sizes[bucket_id] + return window_size + + def make_padded_shapes(shapes, none_filler=None): + padded = [] + for shape in nest.flatten(shapes): + shape = tensor_shape.TensorShape(shape) + shape = [ + none_filler if d.value is None else d + for d in shape + ] + padded.append(shape) + return nest.pack_sequence_as(shapes, padded) + + def batching_fn(bucket_id, grouped_dataset): + """Batch elements in dataset.""" + batch_size = batch_sizes[bucket_id] + none_filler = None + if pad_to_bucket_boundary: + err_msg = ("When pad_to_bucket_boundary=True, elements must have " + "length <= max(bucket_boundaries).") + check = check_ops.assert_less( + bucket_id, + constant_op.constant(len(bucket_batch_sizes) - 1, + dtype=dtypes.int64), + message=err_msg) + with ops.control_dependencies([check]): + boundaries = constant_op.constant(bucket_boundaries, + dtype=dtypes.int64) + bucket_boundary = boundaries[bucket_id] + none_filler = bucket_boundary + shapes = make_padded_shapes( + padded_shapes or grouped_dataset.output_shapes, + none_filler=none_filler) + return grouped_dataset.padded_batch(batch_size, shapes, padding_values) + + def _apply_fn(dataset): + return dataset.apply( + group_by_window(element_to_bucket_id, batching_fn, + window_size_func=window_size_fn)) + + return _apply_fn + + class _VariantDataset(dataset_ops.Dataset): """A Dataset wrapper for a tf.variant-typed function argument.""" diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 3124ca1d1540e12d949dded88ce1c66181be3595..812a50ecbf105393f7e422edbbdf5c87311d72c1 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,101 +17,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib import stateless +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import convert +from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation -class ParallelInterleaveDataset(dataset_ops.Dataset): - """A `Dataset` that maps a function over its input and flattens the result.""" - - def __init__(self, input_dataset, map_func, cycle_length, block_length, - sloppy, buffer_output_elements, prefetch_input_elements): - """See `tf.contrib.data.parallel_interleave()` for details.""" - super(ParallelInterleaveDataset, self).__init__() - self._input_dataset = input_dataset - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_map_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access - dataset = map_func(*nested_args) - else: - dataset = map_func(nested_args) - - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`map_func` must return a `Dataset` object.") - - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - - return dataset._as_variant_tensor() # pylint: disable=protected-access - - self._map_func = tf_map_func - self._map_func.add_to_graph(ops.get_default_graph()) - - self._cycle_length = ops.convert_to_tensor( - cycle_length, dtype=dtypes.int64, name="cycle_length") - self._block_length = ops.convert_to_tensor( - block_length, dtype=dtypes.int64, name="block_length") - self._sloppy = ops.convert_to_tensor( - sloppy, dtype=dtypes.bool, name="sloppy") - self._buffer_output_elements = convert.optional_param_to_tensor( - "buffer_output_elements", - buffer_output_elements, - argument_default=2 * block_length) - self._prefetch_input_elements = convert.optional_param_to_tensor( - "prefetch_input_elements", - prefetch_input_elements, - argument_default=2 * cycle_length) - - def _as_variant_tensor(self): - return gen_dataset_ops.parallel_interleave_dataset( - self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - self._map_func.captured_inputs, - self._cycle_length, - self._block_length, - self._sloppy, - self._buffer_output_elements, - self._prefetch_input_elements, - f=self._map_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) - - @property - def output_classes(self): - return self._output_classes - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def parallel_interleave(map_func, cycle_length, block_length=1, @@ -162,7 +82,7 @@ def parallel_interleave(map_func, @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return ParallelInterleaveDataset( + return readers.ParallelInterleaveDataset( dataset, map_func, cycle_length, block_length, sloppy, buffer_output_elements, prefetch_input_elements) @@ -221,7 +141,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return ParallelInterleaveDataset( + return readers.ParallelInterleaveDataset( dataset, map_func, cycle_length, @@ -231,3 +151,92 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): prefetch_input_elements=None) return _apply_fn + + +class DirectedInterleaveDataset(dataset_ops.Dataset): + """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" + + def __init__(self, selector_input, data_inputs): + self._selector_input = selector_input + self._data_inputs = list(data_inputs) + + for data_input in data_inputs[1:]: + if (data_input.output_types != data_inputs[0].output_types or + data_input.output_classes != data_inputs[0].output_classes): + raise TypeError("All datasets must have the same type.") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_dataset_ops.directed_interleave_dataset( + self._selector_input._as_variant_tensor(), + [data_input._as_variant_tensor() for data_input in self._data_inputs], + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._data_inputs[0].output_classes + + @property + def output_shapes(self): + ret = self._data_inputs[0].output_shapes + for data_input in self._data_inputs[1:]: + ret = nest.pack_sequence_as(ret, [ + ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip( + nest.flatten(ret), nest.flatten(data_input.output_shapes)) + ]) + return ret + + @property + def output_types(self): + return self._data_inputs[0].output_types + + +def sample_from_datasets(datasets, weights=None, seed=None): + """Samples elements at random from the datasets in `datasets`. + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + weights: (Optional.) A list of `len(datasets)` floating-point values where + `weights[i]` represents the probability with which an element should be + sampled from `datasets[i]`, or a @{tf.data.Dataset} object where each + element is such a list. Defaults to a uniform distribution across + `datasets`. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A dataset that interleaves elements from `datasets` at random, according to + `weights` if provided, otherwise with uniform probability. + + Raises: + TypeError: If the `datasets` or `weights` arguments have the wrong type. + ValueError: If the `weights` argument is specified and does not match the + length of the `datasets` element. + """ + num_datasets = len(datasets) + if weights is None: + weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat() + elif not isinstance(weights, dataset_ops.Dataset): + weights = ops.convert_to_tensor(weights, name="weights") + if weights.dtype not in (dtypes.float32, dtypes.float64): + raise TypeError("`weights` must be convertible to a tensor of " + "`tf.float32` or `tf.float64` elements.") + if not weights.shape.is_compatible_with([num_datasets]): + raise ValueError("`weights` must be a vector of length `len(datasets)`.") + weights = dataset_ops.Dataset.from_tensors(weights).repeat() + + # The `stateless_multinomial()` op expects log-probabilities, as opposed to + # weights. + logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + def select_dataset(logits, seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + selector_input = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) + + return DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 96a9e9ed6649444dac5e56d7dd2fcdb62fc56459..e4c9f8b58a2a4390004b0ad318163526b443d44f 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,27 +17,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops # TODO(rohanj): Add a python class that constructs resource in the __init__ # method and provides a get_next() that calls the prefetch op. def function_buffering_resource(string_arg, target_device, - shared_name, f, buffer_size, - thread_pool_size=1, container="", + shared_name=None, name=None): + if shared_name is None: + shared_name = "" return gen_dataset_ops.function_buffering_resource( string_arg=string_arg, target_device=target_device, shared_name=shared_name, f=f, buffer_size=buffer_size, - thread_pool_size=thread_pool_size, container=container, name=name) @@ -49,3 +60,266 @@ def function_buffering_resource_get_next(function_buffer_resource, function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) + + +def function_buffering_resource_reset(function_buffer_resource, name=None): + return gen_dataset_ops.function_buffering_resource_reset( + function_buffer_resource=function_buffer_resource, name=name) + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + device, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + iterator_device = gen_dataset_ops.iterator_get_device( + self._input_iterator._iterator_resource) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=iterator_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name) + + if not self._one_shot: + reset_op = function_buffering_resource_reset(self._buffering_resource) + with ops.control_dependencies([reset_op]): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + self._buffering_resource, + output_types=nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + nest.flatten(ret), nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + + return ret + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): + """A replacement for @{tf.data.Iterator} that prefetches to another device. + + Args: + input_dataset: The input dataset + one_shot: If true, we make a one shot iterator that's already initialized. + device: A fully specified device string where we want to prefetch to + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + device, + buffer_size): + with ops.device("/device:CPU:0"): + super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset) + input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle( + self._resource) + + self._device = device + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self.output_types, self.output_shapes, self.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + _prefetch_fn.add_to_graph(None) + + with ops.device(device): + self._buffering_resource = function_buffering_resource( + f=_prefetch_fn, + target_device=gen_dataset_ops.iterator_get_device(self._resource), + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=iterator_ops._generate_shared_name( + "function_buffer_resource")) + + def _next_internal(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): + with ops.device(self._device): + ret = gen_dataset_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffering_resource, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to another device.""" + + def __init__(self, input_dataset, device, buffer_size): + self._input_dataset = input_dataset + self._device = device + self._buffer_size = buffer_size if buffer_size is not None else 1 + + # The static analysis cannot tell that the eager iterator's superclass has + # a `next()` method. + # pylint: disable=non-iterator-returned + def __iter__(self): + """Creates an `Iterator` for enumerating the elements of this dataset. + + The returned iterator implements the Python iterator protocol and therefore + can only be used in eager mode. + + Returns: + An `Iterator` over the elements of this dataset. + + Raises: + RuntimeError: If eager execution is enabled. + """ + if context.executing_eagerly(): + return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, + self._buffer_size) + else: + raise RuntimeError("dataset.__iter__() is only supported when eager " + "execution is enabled.") + # pylint: enable=non-iterator-returned + + def make_one_shot_iterator(self): + if context.executing_eagerly(): + return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device, + self._buffer_size) + else: + return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True, + device=self._device, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + device=self._device, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_device()` must be the last " + "transformation in a dataset pipeline.") + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_device(device, buffer_size=None): + """A transformation that prefetches dataset values to the given `device`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + device: A string. The name of a device to which elements will be prefetched. + buffer_size: (Optional.) The number of elements to buffer on `device`. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, device, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 7d727165feabb101549567f28a2dfa07083de244..28ef5e50f39dd7d1b6f124e58e068fc968ddd6dc 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -19,11 +19,10 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops @@ -34,16 +33,7 @@ class RandomDataset(dataset_ops.Dataset): def __init__(self, seed=None): """A `Dataset` of pseudorandom values.""" super(RandomDataset, self).__init__() - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) def _as_variant_tensor(self): return gen_dataset_ops.random_dataset( diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 57f30102778f3bac47580f9bdf94e411dfe1b621..4ec8ae1c79d1eb99c56b31c6a0709a84c38f5f90 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,20 +17,551 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import csv +from math import ceil + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import nest +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation + +_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32, + dtypes.int64, dtypes.string) + + +def _is_valid_int32(str_val): + try: + # Checks equality to prevent int32 overflow + return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype( + str_val) + except (ValueError, OverflowError): + return False + + +def _is_valid_int64(str_val): + try: + dtypes.int64.as_numpy_dtype(str_val) + return True + except (ValueError, OverflowError): + return False + + +def _is_valid_float(str_val, float_dtype): + try: + return float_dtype.as_numpy_dtype(str_val) < np.inf + except ValueError: + return False + + +def _infer_type(str_val, na_value, prev_type, float_dtype): + """Given a string, infers its tensor type. + + Infers the type of a value by picking the least 'permissive' type possible, + while still allowing the previous type inference for this column to be valid. + + Args: + str_val: String value to infer the type of. + na_value: Additional string to recognize as a NA/NaN CSV value. + prev_type: Type previously inferred based on values of this column that + we've seen up till now. + float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type + to parse float strings as. + Returns: + Inferred dtype. + """ + if str_val in ("", na_value): + return prev_type + + if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): + return dtypes.int32 + + if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, + dtypes.int64): + return dtypes.int64 + + if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: + return float_dtype + + return dtypes.string + + +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, + comment): + for fn in filenames: + with file_io.FileIO(fn, "r") as f: + rdr = csv.reader( + f, + delimiter=field_delim, + quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE) + if header: + next(rdr) # Skip header lines + + for csv_row in rdr: + if comment is not None and csv_row[0].startswith(comment): + continue # Skip comment lines + + if len(csv_row) != num_cols: + raise ValueError( + "Problem inferring types: CSV row has different number of fields " + "than expected.") + yield csv_row + + +def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, + na_value, header, comment, float_dtype, + num_rows_for_inference, select_columns): + """Infers column types from the first N valid CSV records of files.""" + if select_columns is None: + select_columns = range(num_cols) + inferred_types = [None] * len(select_columns) + + for i, csv_row in enumerate( + _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, + comment)): + if num_rows_for_inference is not None and i >= num_rows_for_inference: + break + + for j, col_index in enumerate(select_columns): + inferred_types[j] = _infer_type(csv_row[col_index], na_value, + inferred_types[j], float_dtype) + + # Replace None's with a default type + inferred_types = [t or dtypes.string for t in inferred_types] + # Default to 0 or '' for null values + return [ + constant_op.constant([0 if t is not dtypes.string else ""], dtype=t) + for t in inferred_types + ] + + +def _infer_column_names(filenames, field_delim, use_quote_delim): + """Infers column names from first rows of files.""" + csv_kwargs = { + "delimiter": field_delim, + "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE + } + with file_io.FileIO(filenames[0], "r") as f: + column_names = next(csv.reader(f, **csv_kwargs)) + + for name in filenames[1:]: + with file_io.FileIO(name, "r") as f: + if next(csv.reader(f, **csv_kwargs)) != column_names: + raise ValueError("Files have different column names in the header row.") + return column_names + + +def _get_sorted_col_indices(select_columns, column_names): + """Transforms select_columns argument into sorted column indices.""" + names_to_indices = {n: i for i, n in enumerate(column_names)} + num_cols = len(column_names) + for i, v in enumerate(select_columns): + if isinstance(v, int): + if v < 0 or v >= num_cols: + raise ValueError( + "Column index %d specified in select_columns out of valid range." % + v) + continue + if v not in names_to_indices: + raise ValueError( + "Value '%s' specified in select_columns not a valid column index or " + "name." % v) + select_columns[i] = names_to_indices[v] + + # Sort and ensure there are no duplicates + result = sorted(set(select_columns)) + if len(result) != len(select_columns): + raise ValueError("select_columns contains duplicate columns") + return result + + +def make_csv_dataset( + file_pattern, + batch_size, + column_names=None, + column_defaults=None, + label_name=None, + select_columns=None, + field_delim=",", + use_quote_delim=True, + na_value="", + header=True, + comment=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + num_parallel_reads=1, + num_parallel_parser_calls=2, + sloppy=False, + default_float_type=dtypes.float32, + num_rows_for_inference=100, +): + """Reads CSV files into a dataset. + + Reads CSV files into a dataset, where each element is a (features, labels) + tuple that corresponds to a batch of CSV rows. The features dictionary + maps feature column names to `Tensor`s containing the corresponding + feature data, and labels is a `Tensor` containing the batch's label data. + + Args: + file_pattern: List of files or patterns of file paths containing CSV + records. See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of consecutive elements of this + dataset to combine in a single batch. + column_names: An optional list of strings that corresponds to the CSV + columns, in order. One per column of the input record. If this is not + provided, infers the column names from the first row of the records. + These names will be the keys of the features dict of each dataset element. + column_defaults: A optional list of default values for the CSV fields. One + item per selected column of the input record. Each item in the list is + either a valid CSV dtype (float32, float64, int32, int64, or string), or a + `Tensor` with one of the aforementioned types. The tensor can either be + a scalar default value (if the column is optional), or an empty tensor (if + the column is required). If a dtype is provided instead of a tensor, the + column is also treated as required. If this list is not provided, tries + to infer types based on reading the first num_rows_for_inference rows of + files specified, and assumes all columns are optional, defaulting to `0` + for numeric values and `""` for string values. If both this and + `select_columns` are specified, these must have the same lengths, and + `column_defaults` is assumed to be sorted in order of increasing column + index. + label_name: A optional string corresponding to the label column. If + provided, the data for this column is returned as a separate `Tensor` from + the features dictionary, so that the dataset complies with the format + expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input + function. + select_columns: An optional list of integer indices or string column + names, that specifies a subset of columns of CSV data to select. If + column names are provided, these must correspond to names provided in + `column_names` or inferred from the file header lines. When this argument + is specified, only a subset of CSV columns will be parsed and returned, + corresponding to the columns specified. Using this results in faster + parsing and lower memory usage. If both this and `column_defaults` are + specified, these must have the same lengths, and `column_defaults` is + assumed to be sorted in order of increasing column index. + field_delim: An optional `string`. Defaults to `","`. Char delimiter to + separate fields in a record. + use_quote_delim: An optional bool. Defaults to `True`. If false, treats + double quotation marks as regular characters inside of the string fields. + na_value: Additional string to recognize as NA/NaN. + header: A bool that indicates whether the first rows of provided CSV files + correspond to header lines with column names, and should not be included + in the data. + comment: An optional character string that marks lines that should not be + parsed as csv records. If this is provided, all lines that start with + this character will not be parsed. + num_epochs: An int specifying the number of times this dataset is repeated. + If None, cycles through the dataset forever. + shuffle: A bool that indicates whether the input should be shuffled. + shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size + ensures better shuffling, but would increase memory usage and startup + time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: An int specifying the number of feature batches to + prefetch for performance improvement. Recommended value is the number of + batches consumed per training step. + num_parallel_reads: Number of threads used to read CSV records from files. + If >1, the results will be interleaved. + num_parallel_parser_calls: Number of parallel invocations of the CSV parsing + function on CSV records. + sloppy: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + default_float_type: Either `tf.float32` or `tf.float64`. If defaults are + not provided, float-like strings are interpreted to be this type. + num_rows_for_inference: Number of rows of a file to use for type inference + if record_defaults is not provided. If None, reads all the rows of all + the files. Defaults to 100. + + Returns: + A dataset, where each element is a (features, labels) tuple that corresponds + to a batch of `batch_size` CSV rows. The features dictionary maps feature + column names to `Tensor`s containing the corresponding column data, and + labels is a `Tensor` containing the column data for the label column + specified by `label_name`. + + Raises: + ValueError: If any of the arguments is malformed. + """ + # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Clean arguments; figure out column names and defaults + if comment is not None and len(comment) != 1: + raise ValueError("`comment` arg must be a single-character string or None") + + if column_names is None: + if not header: + raise ValueError("Cannot infer column names without a header line.") + # If column names are not provided, infer from the header lines + column_names = _infer_column_names(filenames, field_delim, use_quote_delim) + if len(column_names) != len(set(column_names)): + raise ValueError("Cannot have duplicate column names.") + + if select_columns is not None: + select_columns = _get_sorted_col_indices(select_columns, column_names) + + if column_defaults is not None: + column_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in column_defaults + ] + else: + # If column defaults are not provided, infer from records at graph + # construction time + column_defaults = _infer_column_defaults( + filenames, len(column_names), field_delim, use_quote_delim, na_value, + header, comment, default_float_type, num_rows_for_inference, + select_columns) + + if select_columns is not None and len(column_defaults) != len(select_columns): + raise ValueError( + "If specified, column_defaults and select_columns must have same " + "length." + ) + if select_columns is not None and len(column_names) > len(select_columns): + # Pick the relevant subset of column names + column_names = [column_names[i] for i in select_columns] + + if label_name is not None and label_name not in column_names: + raise ValueError("`label_name` provided must be one of the columns.") + + # Define map and filter functions + def filter_fn(line): + return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) + + def filename_to_dataset(filename): + ds = core_readers.TextLineDataset(filename) + if header: + ds = ds.skip(1) + if comment is not None: + ds = ds.filter(filter_fn) + return ds + + def decode_csv(line): + """Decodes CSV line into features. + + Args: + line: String tensor corresponding to one csv record. + Returns: + A dictionary of feature names to values for that particular record. If + label_name is provided, extracts the label feature to be returned as the + second element of the tuple. + """ + columns = parsing_ops.decode_csv( + line, + column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + ) + features = dict(zip(column_names, columns)) + if label_name is not None: + label = features.pop(label_name) + return features, label + return features + + # Read files sequentially or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) + + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + dataset = dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + dataset = dataset.repeat(num_epochs) + + # Use map_and_batch for perf + # TODO(b/76425672): use num_parallel_calls for better performance tuning when + # that is added + dataset = dataset.apply( + batching.map_and_batch( + map_func=decode_csv, + batch_size=batch_size, + num_parallel_batches=int( + ceil(num_parallel_parser_calls / batch_size)))) + + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +def make_batched_features_dataset(file_pattern, + batch_size, + features, + reader=core_readers.TFRecordDataset, + reader_args=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + reader_num_threads=1, + parser_num_threads=2, + sloppy_ordering=False, + drop_final_batch=False): + """Returns a `Dataset` of feature dictionaries from `Example` protos. + Example: + + ``` + serialized_examples = [ + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } + }, + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } + } + ] + ``` + + We can use arguments: + ``` + features: { + "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), + "gender": FixedLenFeature([], dtype=tf.string), + "kws": VarLenFeature(dtype=tf.string), + } + ``` + + And the expected output is: + + ```python + { + "age": [[0], [-1]], + "gender": [["f"], ["f"]], + "kws": SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["code", "art", "sports"] + dense_shape=[2, 2]), + } + ``` + + Args: + file_pattern: List of files or patterns of file paths containing + `Example` records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of consecutive elements of this + dataset to combine in a single batch. + features: A `dict` mapping feature keys to `FixedLenFeature` or + `VarLenFeature` values. See `tf.parse_example`. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. + reader_args: Additional arguments to pass to the reader class. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. Defaults to `None`. + shuffle: A boolean, indicates whether the input should be shuffled. Defaults + to `True`. + shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity + ensures better shuffling but would increase memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: Number of feature batches to prefetch in order to + improve performance. Recommended value is the number of batches consumed + per training step (default is 1). + reader_num_threads: Number of threads used to read `Example` records. If >1, + the results will be interleaved. + parser_num_threads: Number of threads to use for parsing `Example` tensors + into a dictionary of `Feature` tensors. + sloppy_ordering: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + drop_final_batch: If `True`, and the batch size does not evenly divide the + input dataset size, the final smaller batch will be dropped. Defaults to + `False`. + + Returns: + A dataset of `dict` elements. Each `dict` maps feature keys to + `Tensor` or `SparseTensor` objects. + """ + # Create dataset of all matching filenames + filenames = _get_file_names(file_pattern, False) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + if shuffle: + dataset = dataset.shuffle(len(filenames), shuffle_seed) + + # Read `Example` records from files as tensor objects. + if reader_args is None: + reader_args = [] + + # Read files sequentially (if reader_num_threads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + lambda filename: reader(filename, *reader_args), + cycle_length=reader_num_threads, + sloppy=sloppy_ordering)) + + # Extract values if the `Example` tensors are stored as key-value tuples. + if dataset.output_types == (dtypes.string, dtypes.string): + dataset = dataset.map(lambda _, v: v) + + # Apply dataset repeat and shuffle transformations. + repeat_dataset = (num_epochs != 1) + if repeat_dataset and shuffle: + # Used fused shuffle_and_repeat operation for better performance + dataset = dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif repeat_dataset: + dataset = dataset.repeat(num_epochs) + elif shuffle: + dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + + # Parse `Example` tensors to a dictionary of `Feature` tensors. + dataset = dataset.map( + lambda x: parsing_ops.parse_example(x, features), + num_parallel_calls=parser_num_threads) + + # TODO(rachelim): Add an optional label_name argument for extracting the label + # from the features dictionary, to comply with the type expected by the + # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function. + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +@deprecation.deprecated(None, + "Use `tf.contrib.data.make_batched_features_dataset`") def read_batch_features(file_pattern, batch_size, features, - reader, + reader=core_readers.TFRecordDataset, reader_args=None, randomize_input=True, num_epochs=None, @@ -84,43 +615,38 @@ def read_batch_features(file_pattern, dataset to combine in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. - reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of Examples. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. - capacity: Capacity of the ShuffleDataset. A large capacity ensures better + capacity: Buffer size of the ShuffleDataset. A large capacity ensures better shuffling but would increase memory usage and startup time. - Returns: A dict from keys in features to `Tensor` or `SparseTensor` objects. """ - filenames = _get_file_names(file_pattern, randomize_input) - if reader_args: - dataset = reader(filenames, *reader_args) - else: - dataset = reader(filenames) - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) - if num_epochs != 1: - dataset = dataset.repeat(num_epochs) - if randomize_input: - dataset = dataset.shuffle(capacity) - dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) - dataset = dataset.prefetch(1) + dataset = make_batched_features_dataset( + file_pattern, + batch_size, + features, + reader=reader, + reader_args=reader_args, + shuffle=randomize_input, + num_epochs=num_epochs, + shuffle_buffer_size=capacity) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs -def _get_file_names(file_pattern, randomize_input): +def _get_file_names(file_pattern, shuffle): """Parse list of file names from pattern, optionally shuffled. Args: file_pattern: File glob pattern, or list of glob patterns. - randomize_input: Whether to shuffle the order of file names. + shuffle: Whether to shuffle the order of file names. Returns: List of file names matching `file_pattern`. @@ -141,7 +667,7 @@ def _get_file_names(file_pattern, randomize_input): raise ValueError("No files match %s." % file_pattern) # Sort files so it will be deterministic for unit tests. - if not randomize_input: + if not shuffle: file_names = sorted(file_names) return file_names diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index 56f526a330bfbea7305b0754bfd114c5e97db506..b465397437adbdfaf865efb8ed2f80e57f48fcab 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -54,7 +54,7 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" dist_estimation_batch_size = 32 - target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") + target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") class_values_ds = dataset.map(class_func) if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") @@ -101,14 +101,16 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): initial_dist_ds)) .map(maybe_warn_on_large_rejection)) - current_probabilities_ds = dataset_ops.Dataset.zip( - (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) + def _gather_and_copy(class_val, acceptance_prob, data): + return (class_val, array_ops.gather(acceptance_prob, class_val), data) + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) filtered_ds = ( - dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds, - dataset)) + current_probabilities_and_class_and_data_ds .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + return _apply_fn @@ -151,7 +153,7 @@ def _calculate_acceptance_probs(initial_probs, target_probs): ``` - A solution for a_i in terms of the other variabes is the following: + A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` """ # Add tiny to initial_probs to avoid divide by zero. diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 1c88366273f5d186509454188e02350d4ea9f66b..fe49ee8b1946c03dca37c2bc17dde71646b85bbd 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -57,7 +57,7 @@ class _ScanDataset(dataset_ops.Dataset): self._output_shapes = None self._output_types = None - # Iteratively rerun the scan function until reaching a fixed pont on + # Iteratively rerun the scan function until reaching a fixed point on # `self._state_shapes`. need_to_rerun = True while need_to_rerun: diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index 99bb79bc06a421f811869ca9169aaa11deaca2f3..f35795abd38000b13cec0f08596e2ff66e86286c 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -19,11 +19,11 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import random_seed from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed from tensorflow.python.ops import gen_dataset_ops @@ -45,17 +45,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): else: self._count = ops.convert_to_tensor( count, dtype=dtypes.int64, name="count") - - seed, seed2 = random_seed.get_seed(seed) - if seed is None: - self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") - else: - self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") - if seed2 is None: - self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") - else: - self._seed2 = ops.convert_to_tensor( - seed2, dtype=dtypes.int64, name="seed2") + self._seed, self._seed2 = random_seed.get_seed(seed) def _as_variant_tensor(self): # pylint: disable=protected-access diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py new file mode 100644 index 0000000000000000000000000000000000000000..19cc3cb89fc5c494f79ce1d25ed57c92099c8bd2 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sliding dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class _SlideDataset(dataset_ops.Dataset): + """A `Dataset` that passes a sliding window over its input.""" + + def __init__(self, input_dataset, window_size, stride=1): + """See `sliding_window_batch` for details.""" + super(_SlideDataset, self).__init__() + self._input_dataset = input_dataset + self._window_size = ops.convert_to_tensor( + window_size, dtype=dtypes.int64, name="window_size") + self._stride = ops.convert_to_tensor( + stride, dtype=dtypes.int64, name="stride") + + def _as_variant_tensor(self): + return gen_dataset_ops.slide_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + window_size=self._window_size, + stride=self._stride, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + input_shapes = self._input_dataset.output_shapes + return nest.pack_sequence_as(input_shapes, [ + tensor_shape.vector(None).concatenate(s) + for s in nest.flatten(self._input_dataset.output_shapes) + ]) + + @property + def output_types(self): + return self._input_dataset.output_types + + +def sliding_window_batch(window_size, stride=1): + """A sliding window with size of `window_size` and step of `stride`. + + This transformation passes a sliding window over this dataset. The + window size is `window_size` and step size is `stride`. If the left + elements cannot fill up the sliding window, this transformation will + drop the final smaller element. For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { [1], [2], [3], [4], [5], [6] } + + a.apply(tf.contrib.data.sliding_window_batch(window_size=3, stride=2)) == + { + [[1], [2], [3]], + [[3], [4], [5]], + } + ``` + + Args: + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + elements in the sliding window. + stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + steps moving the sliding window forward for one iteration. The default + is `1`. It must be in `[1, window_size)`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _SlideDataset(dataset, window_size, stride) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 9cd1701c397b5a0bf5cc47c1bcab033704794d80..b5cf0fcfe91ebc22444302fca5d488a278ef2994 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -47,7 +47,7 @@ class StatsAggregator(object): dataset = ... iterator = dataset.make_one_shot_iterator() stats_aggregator = stats_ops.StatsAggregator() - set_op = stats_op.set_stats_aggregator_op(iterator, stats_aggregator) + set_op = stats_aggregator.subscribe(iterator) with tf.Session() as sess: # Running `set_op` will associate `iterator` with `stats_aggregator`. diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py new file mode 100644 index 0000000000000000000000000000000000000000..56f67e1766bbaff680bdff6b939df0c3ba68c679 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for controlling threading in `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.ops import resource_variable_ops + +_uid_counter = 0 +_uid_lock = threading.Lock() + + +def _generate_shared_name(prefix): + with _uid_lock: + global _uid_counter + uid = _uid_counter + _uid_counter += 1 + return "{}{}".format(prefix, uid) + + +class PrivateThreadPool(object): + """A stateful resource that represents a private thread pool.""" + + def __init__(self, num_threads, display_name=None): + """Creates a `PrivateThreadPool` with the given number of threads.""" + if context.executing_eagerly(): + shared_name = _generate_shared_name("privatethreadpool") + self._resource = gen_dataset_ops.thread_pool_handle( + num_threads=num_threads, + display_name=display_name, + shared_name=shared_name) + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=context.context().device_name) + else: + self._resource = gen_dataset_ops.thread_pool_handle( + num_threads=num_threads, display_name=display_name) + + +class _ThreadPoolDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and sets a custom threadpool.""" + + def __init__(self, input_dataset, thread_pool): + super(_ThreadPoolDataset, self).__init__() + self._input_dataset = input_dataset + self._thread_pool = thread_pool + + def _as_variant_tensor(self): + return gen_dataset_ops.thread_pool_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._thread_pool._resource, # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def override_threadpool(dataset, thread_pool): + """Returns a new dataset that uses the given thread pool for its operations. + + Args: + dataset: A `tf.data.Dataset` object. + thread_pool: A `PrivateThreadPool` object. + + Returns: + A dataset containing the same values as `dataset`, but which uses + `thread_pool` to compute any of its parallel operations (such as + @{tf.data.Dataset.map}). + """ + return _ThreadPoolDataset(dataset, thread_pool) diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 133e17d20d0fc4c8d52cef3c95c132374e927a0b..765ef3f9b6d42c9d7af3ce4916731d37d65c9260 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -17,11 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes -from tensorflow.python.ops import gen_dataset_ops def unique(): diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index f6de5998d73a4869d2444cd90c9b64d1a2c889ac..3b50a48336d77ebd9327fa24e5612a95d5d0c372 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -13,19 +13,10 @@ load( "tf_pyclif_proto_library", ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], cc_api_version = 2, - go_api_version = 2, java_api_version = 2, visibility = ["//visibility:public"], ) @@ -34,7 +25,6 @@ tf_proto_library( name = "generic_tree_model_extensions", srcs = ["generic_tree_model_extensions.proto"], cc_api_version = 2, - go_api_version = 2, protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/deprecated/BUILD b/tensorflow/contrib/deprecated/BUILD index 3dfbbf55273848afb8ad74ad444f0d85b45610bd..401527f1e74f7725d02a3b92a2c661d8ffc11e21 100644 --- a/tensorflow/contrib/deprecated/BUILD +++ b/tensorflow/contrib/deprecated/BUILD @@ -30,15 +30,3 @@ py_test( "//tensorflow/python:logging_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..74b2cd90a187159fd2da8ce236c14e813cc43c49 --- /dev/null +++ b/tensorflow/contrib/distribute/BUILD @@ -0,0 +1,36 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "distribute", + srcs = ["__init__.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/distribute/python:cross_tower_ops", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:monitor", + "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/distribute/python:step_fn", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5d22d9aa2bb88337e9b740d297cdf87683bdd578 --- /dev/null +++ b/tensorflow/contrib/distribute/README.md @@ -0,0 +1,143 @@ +# Distribution Strategy + +> *NOTE*: This is a experimental feature. The API and performance +> characteristics are subject to change. + +## Overview + +[`DistributionStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/DistributionStrategy) +API is an easy way to distribute your training +across multiple devices/machines. Our goal is to allow users to use existing +models and training code with minimal changes to enable distributed training. +Moreover, we've design the API in such a way that it works with both eager and +graph execution. + +Currently we support one type of strategy, called +[`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy). +It does in-graph replication with synchronous training +on many GPUs on one machine. Essentially, we create copies of all variables in +the model's layers on each device. We then use all-reduce to combine gradients +across the devices before applying them to the variables to keep them in sync. +In the future, we intend to support other kinds of training configurations such +as multi-node, synchronous, +[asynchronous](https://www.tensorflow.org/deploy/distributed#putting_it_all_together_example_trainer_program), +parameter servers and model parallelism. + +## Example + +Let's demonstrate how to use this API with a simple example. We will use the +[`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) +approach, and show you how to scale your model to run on multiple GPUs on one +machine using `MirroredStrategy`. + +Let's consider a very simple model function which tries to learn a simple +function. + +```python +def model_fn(features, labels, mode): + layer = tf.layers.Dense(1) + logits = layer(features) + + if mode == tf.estimator.ModeKeys.PREDICT: + predictions = {"logits": logits} + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + loss = tf.losses.mean_squared_error( + labels=labels, predictions=tf.reshape(logits, [])) + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec(mode, loss=loss) + + if mode == tf.estimator.ModeKeys.TRAIN: + train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn()) + return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op) +``` + +Let's also define a simple input function to feed data for training this model. +Note that we require using +[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) +with `DistributionStrategy`. + + +```python +def input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(100) + labels = tf.data.Dataset.from_tensors(1.).repeat(100) + return dataset_ops.Dataset.zip((features, labels)) +``` + +Now that we have a model function and input function defined, we can define the +estimator. To use `MirroredStrategy`, all we need to do is: + +* Create an instance of the `MirroredStrategy` class. +* Pass it to the +[`RunConfig`](https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig) +parameter of `Estimator`. + + +```python +distribution = tf.contrib.distribute.MirroredStrategy() +config = tf.estimator.RunConfig(train_distribute=distribution) +classifier = tf.estimator.Estimator(model_fn=model_fn, config=config) +classifier.train(input_fn=input_fn) +``` + +That's it! This change will now configure estimator to run on all GPUs on your +machine, with the `MirroredStrategy` approach. It will take care of distributing +the input dataset, replicating layers and variables on each device, and +combining and applying gradients. + +The model and input functions do not have to change because we have changed the +underlying components of TensorFlow (such as +optimizer, batch norm and summaries) to become distribution-aware. +That means those components know how to +combine their state across devices. Further, saving and checkpointing works +seamlessly, so you can save with one or no distribution strategy and resume with +another. + +Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__). +There are few things you can customize in practice: + +* You can specify a list of specific GPUs (using param `devices`) or the number +of GPUs (using param `num_gpus`), in case you don't want auto detection. +* You can specify various parameters for all reduce with the `cross_tower_ops` +param, such as the all reduce algorithm to use, and gradient repacking. + +## Performance Tips + +We've tried to make it such that you get the best performance for your existing +model. We also recommend you follow the tips from +[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance). +Specifically, we found using [`map_and_batch`](https://www.tensorflow.org/performance/datasets_performance#map_and_batch) +and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_performance#pipelining) +in the input function gives a solid boost in performance. When using +`dataset.prefetch`, use `buffer_size=None` to let it detect optimal buffer size. + +## Caveats +This feature is in early stages and there are a lot of improvements forthcoming: + +* Metrics are not yet supported during distributed training. They are still +supported during the evaluation. +* Summaries are only computed in the first tower in `MirroredStrategy`. +* Evaluation is not yet distributed. +* Eager support is in the works; performance can be more challenging with eager +execution. +* As mentioned earlier, multi-node and other distributed strategies will be +introduced in the future. +* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch) +your input data, we will place one batch on each GPU in each step. So your +effective batch size will be `num_gpus * batch_size`. Therefore, consider +adjusting your learning rate or batch size according to the number of GPUs. +We are working on addressing this limitation by splitting each batch across GPUs +instead. +* PartitionedVariables are not supported yet. +* Input pipelines with Datasets that capture stateful objects and rely on +`make_initializable_iterator` are not supported yet. + +## What's next? + +Please give distribution strategies a try. This feature is in early stages and +is evolving, so we welcome your feedback via +[issues on GitHub](https://github.com/tensorflow/tensorflow/issues/new). + + diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76711baf3a11c8978fbb5770ec173ff74a153158 --- /dev/null +++ b/tensorflow/contrib/distribute/__init__.py @@ -0,0 +1,52 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Prototype of a distributed computation library for TF.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.distribute.python.cross_tower_ops import * +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.contrib.distribute.python.monitor import Monitor +from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy +from tensorflow.contrib.distribute.python.step_fn import * +from tensorflow.python.training.distribute import * + +from tensorflow.python.util.all_util import remove_undocumented + + +_allowed_symbols = [ + 'AllReduceCrossTowerOps', + 'CrossTowerOps', + 'DistributionStrategy', + 'MirroredStrategy', + 'Monitor', + 'OneDeviceStrategy', + 'ReductionToOneDeviceCrossTowerOps', + 'Step', + 'StandardInputStep', + 'StandardSingleLossStep', + 'TowerContext', + 'get_cross_tower_context', + 'get_distribution_strategy', + 'get_loss_reduction', + 'get_tower_context', + 'has_distribution_strategy', + 'require_tower_context', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..837a1f134800b0f9faa72f5907d3398964294469 --- /dev/null +++ b/tensorflow/contrib/distribute/python/BUILD @@ -0,0 +1,477 @@ +# Implementation of a prototype TF distributed computation library. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# TODO(priyag): Figure out testonly issues that are preventing us from +# including our tests in pip for now. + +py_library( + name = "values", + srcs = ["values.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":prefetching_ops_v2", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:checkpointable", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device_util", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "values_test", + srcs = ["values_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python:errors", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python:device_util", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "mirrored_strategy", + srcs = ["mirrored_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":cross_tower_ops", + ":shared_variable_creator", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_util", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + "@six_archive//:six", + ], +) + +py_library( + name = "one_device_strategy", + srcs = ["one_device_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":values", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/python:array_ops", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_library( + name = "strategy_test_lib", + testonly = 1, + srcs = ["strategy_test_lib.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_ops", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "combinations", + testonly = 1, + srcs = ["combinations.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":one_device_strategy", + ":tpu_strategy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "combinations_test", + srcs = ["combinations_test.py"], + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "mirrored_strategy_test", + srcs = ["mirrored_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":mirrored_strategy", + ":strategy_test_lib", + "//tensorflow/python:distribute", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_test( + name = "one_device_strategy_test", + srcs = ["one_device_strategy_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":one_device_strategy", + ":strategy_test_lib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:test", + ], +) + +cuda_py_test( + name = "mirrored_strategy_multigpu_test", + srcs = ["mirrored_strategy_multigpu_test.py"], + additional_deps = [ + ":mirrored_strategy", + ":values", + ":strategy_test_lib", + "//tensorflow/python:distribute", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:variable_scope", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "guitar", + "no_pip", + "multi_and_single_gpu", + # Do not perform the extra analysis on this test, because it is already + # performed for the `:mirrored_strategy_test` target. + "no_oss", + "noasan", + "notap", + "notsan", + ], +) + +py_library( + name = "step_fn", + srcs = ["step_fn.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:training", + "//tensorflow/python/eager:backprop", + ], +) + +py_library( + name = "tpu_strategy", + srcs = ["tpu_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/distribute/python:one_device_strategy", + "//tensorflow/contrib/eager/python:datasets", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/contrib/tpu", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_library( + name = "minimize_loss_test_lib", + testonly = 1, + srcs = ["minimize_loss_test.py"], + deps = [ + ":combinations", + ":single_loss_example", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "minimize_loss_test", + srcs = ["minimize_loss_test.py"], + additional_deps = [ + ":minimize_loss_test_lib", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "optimizer_v2_test", + srcs = ["optimizer_v2_test.py"], + additional_deps = [ + ":combinations", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +cuda_py_test( + name = "estimator_integration_test", + srcs = ["estimator_integration_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:dnn_linear_combined", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "single_loss_example", + srcs = ["single_loss_example.py"], + deps = [ + ":step_fn", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +cuda_py_test( + name = "step_fn_test", + srcs = ["step_fn_test.py"], + additional_deps = [ + ":single_loss_example", + ":combinations", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "monitor", + srcs = ["monitor.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "monitor_test", + srcs = ["monitor_test.py"], + additional_deps = [ + ":combinations", + ":monitor", + ":one_device_strategy", + ":single_loss_example", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) + +py_library( + name = "shared_variable_creator", + srcs = ["shared_variable_creator.py"], + visibility = ["//tensorflow:internal"], +) + +py_test( + name = "shared_variable_creator_test", + srcs = ["shared_variable_creator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":shared_variable_creator", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "cross_tower_utils", + srcs = ["cross_tower_utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/nccl:nccl_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +py_library( + name = "cross_tower_ops", + srcs = ["cross_tower_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":cross_tower_utils", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "@six_archive//:six", + ], +) + +py_test( + name = "cross_tower_ops_test", + srcs = ["cross_tower_ops_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + ":combinations", + ":cross_tower_ops", + ":values", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "prefetching_ops_v2", + srcs = ["prefetching_ops_v2.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:contrib_op_loader", + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +cuda_py_test( + name = "prefetching_ops_v2_test", + srcs = ["prefetching_ops_v2_test.py"], + additional_deps = [ + ":prefetching_ops_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py new file mode 100644 index 0000000000000000000000000000000000000000..1f66997e6eca0b72c2d6eb52ace38cc568ef5d42 --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -0,0 +1,312 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Facilities for creating multiple test combinations. + +Here is an example of testing various optimizers in Eager and Graph mode: + +class AdditionExample(test.TestCase, parameterized.TestCase): + @combinations.generate( + combinations.combine(mode=["graph", "eager"], + optimizer=[AdamOptimizer(), + GradientDescentOptimizer()])) + def testOptimizer(self, optimizer): + ... f(optimizer)... + +This will run `testOptimizer` 4 times with the specified optimizers: 2 in +Eager and 2 in Graph mode. +The test will be provided with arguments that match the arguments of combine +by name. It is necessary to request all arguments, except for `mode`, which is +optional. + +`combine()` function is available for creating a cross product of various +options. `times()` function exists for creating a product of N `combine()`-ed +results. See below. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import sys +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.contrib.optimizer_v2 import adam as adam_v2 +from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.util import tf_inspect + + +GPU_TEST = "test_gpu" in sys.argv[0] +TPU_TEST = "test_tpu" in sys.argv[0] + + +def generate(combinations): + """A decorator for generating test cases of a test method or a test class. + + Args: + combinations: a list of dictionaries created using combine() and times(). + + Restrictions: + -- there should always be a "mode" argument. Accepted values are "eager" + and "graph". + -- arguments of the test method must match by name to get the corresponding + value of the combination. Tests must accept all arguments (except "mode", + which is optional). + -- distribution argument is special. It is meant for passing instances of + DistributionStrategy. Each instance is to be passed as `(, + )` tuple, where is the number of required + GPUs. If the required number of GPUs for the DistributionStrategy isn't + available then the test case is going to be skipped. + + Returns: + a decorator that will cause the test method to be run under the specified + conditions. + + Raises: + ValueError - if "mode" argument wasn't either "eager" or "graph. + """ + + def decorator(test_function): + """The decorator to be returned.""" + + # Generate good test names that can be used with --test_filter. + for combination in combinations: + # We use OrderedDicts in `combine()` and `times()` to ensure stable + # order of keys in each dictionary. + assert isinstance(combination, OrderedDict) + name = "".join([ + "_{}_{}".format( + "".join(filter(str.isalnum, key)), + "".join(filter(str.isalnum, str(value)))) + for key, value in combination.items() + ]) + combination.update({"testcase_name": "_test{}".format(name)}) + + @parameterized.named_parameters(*combinations) + def decorated(self, **kwargs): + """A wrapped test method that sets up `test_function`.""" + assert "mode" in kwargs + mode = kwargs["mode"] + + if "distribution" in kwargs: + distribution = kwargs["distribution"] + kwargs["distribution"] = distribution.strategy + if distribution.required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not distribution.required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not distribution.required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < distribution.required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(distribution.required_gpus, context.num_gpus())) + + requested_arguments = tf_inspect.getfullargspec(test_function).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with context.eager_mode(), ops.Graph().as_default(): + test_function(**kwargs_to_pass) + elif mode == "graph": + with context.graph_mode(), ops.Graph().as_default(): + test_function(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + + return decorated + return decorator + + +def combine(**kwargs): + """Generate combinations based on its keyword arguments. + + Two sets of returned combinations can be concatenated using +. Their product + can be computed using `times()`. + + Args: + **kwargs: keyword arguments of form `option=[possibilities, ...]`. + + Returns: + a list of dictionaries for each combination. Keys in the dictionaries are + the keyword argument names. Each key has one value - one of the + corresponding keyword argument values. + """ + if not kwargs: + return [OrderedDict()] + + sort_by_key = lambda k: k[0][0] + kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) + first = list(kwargs.items())[0] + + rest = dict(list(kwargs.items())[1:]) + rest_combined = combine(**rest) + + key = first[0] + values = first[1] + + return [ + OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) + for v in values + for combined in rest_combined + ] + + +def times(*combined): + """Generate a product of N sets of combinations. + + times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4]) + + Args: + *combined: N lists of dictionaries that specify combinations. + + Returns: + a list of dictionaries for each combination. + + Raises: + ValueError: if some of the inputs have overlapping keys. + """ + assert combined + + if len(combined) == 1: + return combined[0] + + first = combined[0] + rest_combined = times(*combined[1:]) + + combined_results = [] + for a in first: + for b in rest_combined: + if set(a.keys()).intersection(set(b.keys())): + raise ValueError("Keys need to not overlap: {} vs {}".format( + a.keys(), b.keys())) + + combined_results.append(OrderedDict(list(a.items()) + list(b.items()))) + return combined_results + + +class NamedObject(object): + """A class that translates an object into a good test name.""" + + def __init__(self, name, obj): + self._name = name + self._obj = obj + + def __getattr__(self, name): + return getattr(self._obj, name) + + def __call__(self, *args, **kwargs): + return self._obj(*args, **kwargs) + + def __repr__(self): + return self._name + + +class NamedDistribution(object): + """Translates DistributionStrategy and its data into a good name.""" + + def __init__(self, name, distribution, required_gpus=None, + required_tpu=False): + self._distribution = distribution + self._name = name + self._required_gpus = required_gpus + self._required_tpu = required_tpu + + def __repr__(self): + return self._name + + @property + def strategy(self): + return self._distribution + + @property + def required_gpus(self): + return self._required_gpus + + @property + def required_tpu(self): + return self._required_tpu + + +one_device_strategy = NamedDistribution( + "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), + None) +tpu_strategy = NamedDistribution( + "TPU", tpu_strategy.TpuStrategy(), required_tpu=True) +mirrored_strategy_with_gpu_and_cpu = NamedDistribution( + "MirroredCPUAndGPU", + mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) +mirrored_strategy_without_prefetch = NamedDistribution( + "MirroredCPUAndGPUNoPrefetch", + mirrored_strategy.MirroredStrategy( + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1) +mirrored_strategy_with_two_gpus = NamedDistribution( + "Mirrored2GPUs", + mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + +adam_optimizer_v1_fn = NamedObject( + "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v1_fn = NamedObject( + "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) + +adam_optimizer_v2_fn = NamedObject( + "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) +gradient_descent_optimizer_v2_fn = NamedObject( + "GradientDescentV2", + lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) + +graph_and_eager_modes = ["graph", "eager"] + + +def distributions_and_v1_optimizers(): + """A common set of combination with DistributionStrategies and Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]) + + +def distributions_and_v2_optimizers(): + """DistributionStrategies and V2 Optimizers.""" + return combine( + distribution=[ + one_device_strategy, mirrored_strategy_with_gpu_and_cpu, + mirrored_strategy_with_two_gpus + ], + optimizer_fn=[adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]) diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py new file mode 100644 index 0000000000000000000000000000000000000000..219b24160f3902fcfa5363cc39a8fc5b30d00308 --- /dev/null +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for some testing utils from strategy_test_lib.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.eager import test + + +class TestingCombinationsTest(test.TestCase): + + def test_combine(self): + self.assertEqual([{ + "a": 1, + "b": 2 + }, { + "a": 1, + "b": 3 + }, { + "a": 2, + "b": 2 + }, { + "a": 2, + "b": 3 + }], combinations.combine(a=[1, 2], b=[2, 3])) + + def test_add(self): + self.assertEqual( + [{ + "a": 1 + }, { + "a": 2 + }, { + "b": 2 + }, { + "b": 3 + }], + combinations.combine(a=[1, 2]) + + combinations.combine(b=[2, 3])) + + def test_times(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1 + c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d1"), ("loss", "callable"), + ("mode", "eager")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "tensor"), + ("mode", "graph")]), + OrderedDict([("distribution", "d2"), ("loss", "callable"), + ("mode", "eager")]) + ], c4) + + def test_times_variable_arguments(self): + c1 = combinations.combine(mode=["graph", "eager"]) + c2 = combinations.combine(optimizer=["adam", "gd"]) + c3 = combinations.combine(distribution=["d1", "d2"]) + c4 = combinations.times(c3, c1, c2) + self.assertEqual([ + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d1"), ("mode", "eager"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "graph"), + ("optimizer", "gd")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "adam")]), + OrderedDict([("distribution", "d2"), ("mode", "eager"), + ("optimizer", "gd")]) + ], c4) + self.assertEqual( + combinations.combine( + mode=["graph", "eager"], + optimizer=["adam", "gd"], + distribution=["d1", "d2"]), c4) + + def test_overlapping_keys(self): + c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"]) + c2 = combinations.combine(mode=["eager"], loss=["callable"]) + with self.assertRaisesRegexp(ValueError, ".*Keys.+overlap.+"): + _ = combinations.times(c1, c2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..cff717db80f0bdd377b3c9c7e8ca3578ff273930 --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -0,0 +1,586 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes for different algorithms of reduction and broadcasting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.client import device_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import device_util + + +def _validate_destinations(destinations): + if not isinstance(destinations, + (value_lib.DistributedValues, six.string_types, list)): + raise ValueError("destinations must be one of a `DistributedValues` object," + " a device string, a list of device strings or None") + + if not destinations: + raise ValueError("destinations can not be empty") + + +def _validate_value_destination_pairs(value_destination_pairs): + # pylint: disable=g-missing-docstring + if not value_destination_pairs: return False + if not isinstance(value_destination_pairs, (list, tuple)): return False + if not all([isinstance(pair, tuple) for pair in value_destination_pairs]): + return False + if not all([isinstance(v[0], value_lib.PerDevice) + for v in value_destination_pairs]): + return False + return True + + +def _get_devices_from(destinations): + if isinstance(destinations, value_lib.DistributedValues): + return list(destinations.devices) + elif isinstance(destinations, six.string_types): + return [device_util.canonicalize(destinations)] + else: + return [ + device_util.canonicalize(destination) for destination in destinations + ] + + +def _devices_match(left, right): + return set(_get_devices_from(left)) == set(_get_devices_from(right)) + + +def _all_devices_match(value_destination_pairs): + if not all([d is None or _devices_match(v, d) + for v, d in value_destination_pairs]): + return False + if not all([_devices_match(v, value_destination_pairs[0][0]) + for v, _ in value_destination_pairs[1:]]): + return False + return True + + +def _simple_broadcast(tensor, destinations): + index = {} + devices = _get_devices_from(destinations) + for d in devices: + with ops.device(d): + index[d] = array_ops.identity(tensor) + return value_lib.Mirrored(index) + + +def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, + method_string): + # pylint: disable=g-missing-docstring + all_values = [] + count = 0 + for v in per_device_value._index.values(): # pylint: disable=protected-access + if isinstance(v, value_lib.MapOutput): + v_list = v.get() + if not v_list: + continue + count += len(v_list) + # Sum within each device before aggregating across devices. + v = math_ops.add_n(v_list) + else: + count += 1 + all_values.append(v) + if not all_values: + raise ValueError("`per_device_value` must be non-empty") + + with ops.device(reduce_to_device): + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + if method_string == "sum": + reduced = accumulation_fn(all_values) + elif method_string == "mean": + reduced = accumulation_fn(all_values) / count + else: + raise ValueError("`method_string` must be 'sum' or 'mean'") + return reduced + + +class CrossTowerOps(object): + """Base class for cross-tower reduction and broadcasting algorithms.""" + + def __init__(self): + pass + + def reduce(self, method_string, per_device_value, destinations=None): + """Reduce `per_device_value` to `destinations`. + + It runs the reduction operation defined by `method_string` and put the + result on `destinations`. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + per_device_value: a PerDevice object. + destinations: the reduction destinations. + + Returns: + a Mirrored object. + + Raises: + ValueError: if per_device_value is not a PerDevice object. + """ + if not isinstance(per_device_value, value_lib.PerDevice): + raise ValueError("`per_device_value` must be a `PerDevice` object.") + if destinations is not None: + _validate_destinations(destinations) + return self._reduce(method_string, per_device_value, destinations) + + def batch_reduce(self, method_string, value_destination_pairs): + """Reduce PerDevice objects in a batch. + + Reduce each first element in `value_destination_pairs` to each second + element which indicates the destinations. + + Args: + method_string: either 'sum' or 'mean' specifying the reduction method. + value_destination_pairs: a list or a tuple of tuples of PerDevice objects + and destinations. If a destination is None, then the destinations + are set to match the devices of the input PerDevice object. + + Returns: + a list of Mirrored objects. + + Raises: + ValueError: if `value_destination_pairs` is not a list or a tuple of + tuples of PerDevice objects and destinations + """ + if not _validate_value_destination_pairs(value_destination_pairs): + raise ValueError("`value_destination_pairs` must be a list or a tuple of " + "tuples of PerDevice objects and destinations") + for _, d in value_destination_pairs: + if d is not None: + _validate_destinations(d) + + return self._batch_reduce(method_string, value_destination_pairs) + + def broadcast(self, tensor, destinations): + """Broadcast the `tensor` to destinations. + + Args: + tensor: the tensor to broadcast. + destinations: the broadcast destinations. + + Returns: + a Mirrored object. + """ + _validate_destinations(destinations) + return self._broadcast(tensor, destinations) + + def _reduce(self, method_string, per_device_value, destinations): + raise NotImplementedError( + "_reduce method must be implemented in descendants.") + + def _batch_reduce(self, method_string, value_destination_pairs): + raise NotImplementedError( + "_batch_reduce method must be implemented in descendants.") + + def _broadcast(self, tensor, destinations): + return _simple_broadcast(tensor, destinations) + + +class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): + """Always do reduction to one device first and then do broadcasting. + + Batch reduction is done by reduction on each element one by one. + """ + + def __init__(self, reduce_to_device=None, accumulation_fn=math_ops.add_n): + """Constructor. + + Args: + reduce_to_device: the intermediate device to reduce to. If None, reduce + to the first device in `destinations` of the reduce() method. + accumulation_fn: a function that does accumulation. + """ + self.reduce_to_device = reduce_to_device + self.accumulation_fn = accumulation_fn + super(ReductionToOneDeviceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = self.reduce_to_device or devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + self.accumulation_fn, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + return [self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs] + + +def _group_value_by_device(per_device_values): + """Group values into sublists by their devices. + + This grouping is needed to call the all-reduce library. + + Args: + per_device_values: a list of PerDevice obejcts. + + Returns: + a list of lists, each sublist has components for its corresponding device of + PerDevice objects, paired with a None. + """ + destinations = per_device_values[0].devices + grouped = [[] for _ in range(len(destinations))] + for per_device_value in per_device_values: + # pylint: disable=protected-access + for i, v in enumerate(per_device_value._index.values()): + assert per_device_value.devices == destinations + grouped[i].append((v, None)) + return grouped + + +def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string): + """Ungroup results from all-reduce and make Mirrored objects. + + Each all-reduce result will be divided by the number of destinations before + Mirrored objects are created if method_string is "mean". + + Args: + grouped_reduced: a list of lists, each sublist has components for each + device, paired with a None. It is the result from + cross_tower_utils.aggregate_gradients_using*. + destinations: a list of device strings for returned Mirrored objects. + method_string: "mean" or "sum". + + Returns: + a list of Mirrored objects. + """ + index = [{} for _ in range(len(grouped_reduced[0]))] + for d, per_device_reduced in enumerate(grouped_reduced): + for i, (v, _) in enumerate(per_device_reduced): + if method_string == "mean": + index[i][destinations[d]] = v / len(destinations) + else: + index[i][destinations[d]] = v + return [value_lib.Mirrored(v) for v in index] + + +class ConcatAndSplitPacker(object): + """Concatenate and split tensors for reduction.""" + + def __init__(self, num_packs=1): + """Initialize the ConcatAndSplitPacker object. + + Args: + num_packs: specifies the number of split packs that will be + formed. + + Raises: + ValueError: if num_packs is not greater than 0. + """ + if num_packs <= 0: + raise ValueError("num_packs must be greater than zero.") + self.num_packs = num_packs + + def pack(self, grouped_grads_and_vars): + """Pack tensors.""" + self.grouped_grads_and_vars = grouped_grads_and_vars + self.all_tower_shapes = [] + self.all_tower_sizes = [] + + device_grad_packs = [] + for tower_grads_and_vars in grouped_grads_and_vars: + with ops.colocate_with(tower_grads_and_vars[0][0]): + # Flatten all the grads. + flat_grads = [ + array_ops.reshape(g, [-1]) for g, _ in tower_grads_and_vars + ] + # Remember the original shape of all the grads. + tower_shapes = [array_ops.shape(g) for g, _ in tower_grads_and_vars] + # Remember the original sizes of all the grads. + tower_sizes = [array_ops.size(g) for g, _ in tower_grads_and_vars] + # Concat all the flat grads into a big flat tensor. + concat_grads = array_ops.concat(flat_grads, 0) + + # Split the big tensor into num_splits packs. In cases where the + # total size is not divisible num_splits, the last pack gets + # more elements. + # TODO(zhengxq): it is also possible to optimize away all the concat + # as well. + num_splits = self.num_packs + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits + split_size_last = total_grad_size - split_size * (num_splits - 1) + split_sizes = [split_size] * (num_splits - 1) + [split_size_last] + grad_packs = array_ops.split(concat_grads, split_sizes) + + # Ready to aggregate the repacked gradients, with fake variables. + # TODO(zhengxq): It is hacky to have to use fake variables. + # We should remove the need for variables in + # aggregate_gradients_using*. + device_grad_packs.append(zip(grad_packs, [None] * num_splits)) + self.all_tower_shapes.append(tower_shapes) + self.all_tower_sizes.append(tower_sizes) + + return device_grad_packs + + def unpack(self, summed_device_grad_packs): + """Reverse the pack.""" + aggregated_device_grads = [] + for (summed_tower_grad_packs, + tower_grads_and_vars, tower_shapes, tower_sizes) in zip( + summed_device_grad_packs, self.grouped_grads_and_vars, + self.all_tower_shapes, self.all_tower_sizes): + # pylint: enable=line-too-long + # Reverse the packing operations in the previous steps. Form the + # summed gradients back into their original shapes. + with ops.colocate_with(summed_tower_grad_packs[0][0]): + # Form a list of the summed grad packs. + device_grad_packs = [g for g, _ in summed_tower_grad_packs] + + # Concat them back into a big flat tensor. + device_grads_concat = array_ops.concat(device_grad_packs, 0) + + # Split the tensors back into their original sizes. + grads_with_sizes = array_ops.split(device_grads_concat, tower_sizes) + + # Reshape the tensors back into their original shapes. + grads_with_shapes = [ + array_ops.reshape(grad, shape) + for shape, grad in zip(tower_shapes, grads_with_sizes) + ] + + # Form the list with the original list of variables. + summed_tower_grads = [ + (g, v) for g, (_, v) in zip(grads_with_shapes, tower_grads_and_vars) + ] + aggregated_device_grads.append(summed_tower_grads) + return aggregated_device_grads + + +class AggregateSmallTensorPacker(object): + """Concatenate small gradient tensors together for reduction.""" + + def __init__(self, + agg_small_grads_max_bytes=1048576, + agg_small_grads_max_group=16): + """Initialize the AggregateSmallTensorPacker object. + + Args: + agg_small_grads_max_bytes: largest tensor eligible for aggregation, + in number of bytes. + agg_small_grads_max_group: largest permitted aggregation of small + tensors. + + Raises: + ValueError: if `agg_small_grads_max_bytes` or `agg_small_grads_max_group` + is not greater than 0. + """ + if agg_small_grads_max_bytes <= 0 or agg_small_grads_max_group <= 0: + raise ValueError("agg_small_grads_max_bytes and agg_small_grads_max_group" + " should both be greater than zero.") + self.agg_small_grads_max_bytes = agg_small_grads_max_bytes + self.agg_small_grads_max_group = agg_small_grads_max_group + + def pack(self, grouped_grads_and_vars): + """Aggregate small tensors.""" + if (self.agg_small_grads_max_bytes > 0 and + self.agg_small_grads_max_group > 0): + tower_grads, self.packing = cross_tower_utils.pack_small_tensors( + grouped_grads_and_vars, + max_bytes=self.agg_small_grads_max_bytes, + max_group=self.agg_small_grads_max_group) + return tower_grads + + def unpack(self, summed_device_grad_packs): + """Reverse the aggregation process.""" + return cross_tower_utils.unpack_small_tensors(summed_device_grad_packs, + self.packing) + + +class AllReduceCrossTowerOps(CrossTowerOps): + """Reduction using all reduce.""" + + def __init__(self, + all_reduce_alg="nccl", + num_packs=1, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """All-reduce implementation of CrossTowerOps. + + Before performing all-reduce, tensors will be repacked or aggregated for + more efficient cross-device transportation: + 1) If `num_packs` is non-zero, pack values into + `num_packs` splits. + 2) Otherwise, if `agg_small_grads_max_bytes` > 0 and + `agg_small_grads_max_group` > 0, aggregate values smaller than + `agg_small_grads_max_bytes` into groups with at most + `agg_small_grads_max_group` values. + 3) Otherwise, no repacking or grouping will happen. + + Args: + all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or + "hierarchical_copy" are supported. + num_packs: see above. + agg_small_grads_max_bytes: see above. + agg_small_grads_max_group: see above. + tensors. + """ + self.all_reduce_alg = all_reduce_alg + self.num_packs = num_packs + self.agg_small_grads_max_bytes = agg_small_grads_max_bytes + self.agg_small_grads_max_group = agg_small_grads_max_group + super(AllReduceCrossTowerOps, self).__init__() + + def _reduce(self, method_string, per_device_value, destinations): + if ((destinations is None or _devices_match(per_device_value, destinations)) + and not context.executing_eagerly()): + return self._batch_all_reduce(method_string, [per_device_value])[0] + else: + devices = _get_devices_from(destinations or per_device_value) + reduce_to_device = devices[0] + reduced = _simple_reduce(per_device_value, reduce_to_device, + math_ops.add_n, method_string) + return self.broadcast(reduced, devices) + + def _batch_reduce(self, method_string, value_destination_pairs): + if (_all_devices_match(value_destination_pairs) and + not context.executing_eagerly()): + return self._batch_all_reduce(method_string, + [v[0] for v in value_destination_pairs]) + else: + if not context.executing_eagerly(): + logging.warning("Efficient batch_reduce is not supported if " + "destinations are different.") + return [ + self._reduce(method_string, t, destinations=v) + for t, v in value_destination_pairs + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + destinations = per_device_values[0].devices + grouped = _group_value_by_device(per_device_values) + if self.num_packs > 0: + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s and num_packs = %d", len(per_device_values), + self.all_reduce_alg, self.num_packs) + tensor_packer = ConcatAndSplitPacker(self.num_packs) + device_grad_packs = tensor_packer.pack(grouped) + elif (self.agg_small_grads_max_bytes > 0 and + self.agg_small_grads_max_group > 0): + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self.all_reduce_alg, self.agg_small_grads_max_bytes, + self.agg_small_grads_max_group) + tensor_packer = AggregateSmallTensorPacker( + self.agg_small_grads_max_bytes, self.agg_small_grads_max_group) + device_grad_packs = tensor_packer.pack(grouped) + else: + logging.info( + "batch_all_reduce invoked for batches size = %d with algorithm = %s", + len(per_device_values), self.all_reduce_alg) + tensor_packer = None + device_grad_packs = grouped + + # The actual aggregation of the repacked gradients. Note that they are + # sharded among different aggregation trees. So it is important to strike + # the balance on num_splits. + if self.all_reduce_alg == "nccl": + reduced = cross_tower_utils.aggregate_gradients_using_nccl( + device_grad_packs) + else: + # TODO(yuefengz): check that gpu ids in `destinations` are in ascending + # order. + reduced = ( + cross_tower_utils.aggregate_gradients_using_hierarchical_copy( + destinations, device_grad_packs)) + + if tensor_packer: + reduced = tensor_packer.unpack(reduced) + + return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, + method_string) + + +_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + + +def _has_dgx1_like_links(gpu_links): + if not gpu_links: + return False + # TODO(yuefengz): figure out the right topology for hierarchial copy if + # number of gpus are less than 8. + if len(gpu_links) < 8: + return False + for i, (gpu_link, dgx1_link) in enumerate(zip(gpu_links, _dgx1_links)): + if (set(gpu_link) != set(dgx1_link) and + set(gpu_link) != set(dgx1_link + [i])): + return False + return True + + +def _choose_all_reduce_algorithm(device_links): + if _has_dgx1_like_links(device_links): + logging.info("Configured hierarchical_copy with num_packs=%d", + len(device_links)) + return AllReduceCrossTowerOps( + "hierarchical_copy", num_packs=len(device_links)) + else: + logging.info("Configured nccl all-reduce.") + return AllReduceCrossTowerOps("nccl", num_packs=1) + + +def choose_the_best(devices, session_config=None): + """Find the best subclass of CrossTowerOps given a tensorflow session. + + Args: + devices: a list of devices passed for distribute strategy. + session_config: a tensorflow session config or None. If None, it will make + deciesion based on all local devices. + + Returns: + a subclass of CrossTowerOps. + """ + requested_devices = set([device_util.canonicalize(d) for d in devices]) + machine_devices = device_lib.list_local_devices(session_config=session_config) + using_devices = [] + for d in machine_devices: + if device_util.canonicalize(d.name) in requested_devices: + using_devices.append(d) + else: + logging.info( + "Device is available but not used by distribute strategy: %s", d.name) + + if len(using_devices) != len(requested_devices): + logging.warning("Not all devices in distribute strategy are visible by " + "TensorFlow sessions.") + return ReductionToOneDeviceCrossTowerOps() + + if any([d.device_type.lower() != "gpu" for d in using_devices]): + logging.warning("Not all devices in DistributionStrategy are visible to " + "TensorFlow session.") + return ReductionToOneDeviceCrossTowerOps() + + device_links = [[] for _ in range(len(using_devices))] + for i, device in enumerate(using_devices): + for link in device.locality.links.link: + device_links[i].append(link.device_id) + + return _choose_all_reduce_algorithm(device_links) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7b0870887465ec2fe40007695d099277db38bf --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -0,0 +1,221 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for CrossTowerOps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _make_per_device(values, devices): + devices = cross_tower_ops_lib._get_devices_from(devices) + assert len(values) == len(devices) + index = {} + for d, v in zip(devices, values): + with ops.device(d): + placed_v = array_ops.identity(v) + index[d] = placed_v + return value_lib.PerDevice(index) + + +# pylint: disable=g-doc-args,g-doc-return-or-yield +def _fake_mirrored(value, devices): + """Create a faked Mirrored object for testing. + + All components of the returned Mirrored have the same objects, which is not + true in reality. + """ + devices = cross_tower_ops_lib._get_devices_from(devices) + return value_lib.Mirrored( + {d: v for d, v in zip(devices, [value] * len(devices))}) + + +_cpu_device = "/device:CPU:0" + + +class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): + + def _assert_value_equal(self, left, right): + if isinstance(left, list): + for l, r in zip(left, right): + self._assert_value_equal(l, r) + else: + self.assertEqual(type(left), type(right)) + self.assertEqual(left.devices, right.devices) + if context.executing_eagerly(): + self.assertEqual([v.numpy() for v in left._index.values()], + list(right._index.values())) + else: + with self.test_session() as sess: + self.assertEqual( + sess.run(list(left._index.values())), list(right._index.values())) + + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + devices = distribution.worker_devices + + values = [constant_op.constant(float(d)) for d in range(len(devices))] + per_device = _make_per_device(values, devices) + mean = (len(devices) - 1.) / 2. + + values_2 = [constant_op.constant(d + 1.0) for d in range(len(devices))] + per_device_2 = _make_per_device(values_2, devices) + mean_2 = mean + 1. + + destination_mirrored = _fake_mirrored(1., devices) + destination_different = _fake_mirrored(1., _cpu_device) + destination_str = _cpu_device + destination_list = devices + + all_destinations = [ + None, destination_mirrored, destination_different, destination_str, + destination_list + ] + + # test reduce() + for destinations in all_destinations: + self._assert_value_equal( + cross_tower_ops.reduce("mean", per_device, destinations=destinations), + _fake_mirrored(mean, destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce( + "mean", per_device_2, destinations=destinations), + _fake_mirrored(mean_2, destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce("sum", per_device, destinations=destinations), + _fake_mirrored(mean * len(devices), destinations or per_device)) + self._assert_value_equal( + cross_tower_ops.reduce( + "sum", per_device_2, destinations=destinations), + _fake_mirrored(mean_2 * len(devices), destinations or per_device)) + + # test batch_reduce() + for d1, d2 in itertools.product(all_destinations, all_destinations): + self._assert_value_equal( + cross_tower_ops.batch_reduce( + "mean", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean, d1 or per_device), + _fake_mirrored(mean_2, d2 or per_device_2)]) + self._assert_value_equal( + cross_tower_ops.batch_reduce( + "sum", [(per_device, d1), (per_device_2, d2)]), + [_fake_mirrored(mean * len(devices), d1 or per_device), + _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)]) + + # test broadcast() + for destinations in all_destinations: + if destinations is None: + continue + else: + self._assert_value_equal( + cross_tower_ops.broadcast(constant_op.constant(1.), destinations), + _fake_mirrored(1., destinations)) + + def testChooseAlgorithm(self): + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], + [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "hierarchical_copy") + self.assertEqual(result.num_packs, 8) + + # if there are only 4 devices + device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "nccl") + self.assertEqual(result.num_packs, 1) + + # if devices links contain each device itself + device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], + [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], + [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "hierarchical_copy") + self.assertEqual(result.num_packs, 8) + + # if not dgx1-like links + device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], + [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] + result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) + self.assertTrue( + isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) + self.assertEqual(result.all_reduce_alg, "nccl") + self.assertEqual(result.num_packs, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc04e2195f6d305e0f7c642f24c355286f1a8cfa --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -0,0 +1,339 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for cross_tower_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections as pycoll + +from tensorflow.contrib import nccl +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def aggregate_gradients_using_nccl(tower_grads): + """Aggregate gradients using nccl allreduce.""" + agg_all_g_and_v = [] + for single_g_and_v in zip(*tower_grads): + single_grads = [g for g, _ in single_g_and_v] + agg_grads = nccl.all_sum(single_grads) + agg_all_g_and_v.append( + [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) + + agg_all_g_and_v = list(zip(*agg_all_g_and_v)) + + return agg_all_g_and_v + + +def aggregate_gradients_using_hierarchical_copy(avail_devices, tower_grads): + """Aggregate gradients using hierarchical copies. + + Args: + avail_devices: available GPU devices. + tower_grads: List of lists of (gradient, variable) tuples. The outer list + is over towers. The inner list is over individual gradients. + + Returns: + The list of (aggregated_gradient, variable), where the gradient has been + summed across all towers and the variable is chosen from the first tower. + """ + # This only works for DGX-1 type of machine topology + # Device peer to peer matrix + # DMA: 0 1 2 3 4 5 6 7 + # 0: Y Y Y Y Y N N N + # 1: Y Y Y Y N Y N N + # 2: Y Y Y Y N N Y N + # 3: Y Y Y Y N N N Y + # 4: Y N N N Y Y Y Y + # 5: N Y N N Y Y Y Y + # 6: N N Y N Y Y Y Y + # 7: N N N Y Y Y Y Y + agg_grads = [] + num_devices = len(avail_devices) + # In the special case of DGX-1 machine topology, the two groups have equal + # size. + group_size = num_devices // 2 + for i, single_grads in enumerate(zip(*tower_grads)): + group_0_main_device = i % num_devices + group_1_main_device = (group_0_main_device + group_size) % num_devices + if group_0_main_device < group_size: + group_0_begin = 0 + group_1_begin = group_size + else: + group_0_begin = group_size + group_1_begin = 0 + + # Aggregate the first group. + group_0_device_grads = single_grads[group_0_begin: + group_0_begin + group_size] + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads, _ = aggregate_single_gradient_using_copy( + group_0_device_grads, False, False) + + # Aggregate the second group. + group_1_device_grads = single_grads[group_1_begin: + group_1_begin + group_size] + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads, _ = aggregate_single_gradient_using_copy( + group_1_device_grads, False, False) + + # Aggregate between the groups. + with ops.device(avail_devices[group_0_main_device]): + (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( + [group_0_agg_grads, group_1_agg_grads], False, False) + + # Broadcast the result back into the root of each group. + with ops.device(avail_devices[group_0_main_device]): + group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) + with ops.device(avail_devices[group_1_main_device]): + group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) + + agg_grads_bcast = [] + for j in range(len(single_grads)): + with ops.device(avail_devices[j]): + # Broadcast the result back to each member in the group from the root. + if (group_0_main_device < group_size) == (j < group_size): + src_device_grad = group_0_agg_grads_bcast + else: + src_device_grad = group_1_agg_grads_bcast + agg_grads_bcast.append(array_ops.identity(src_device_grad)) + + agg_grads.append( + [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) + + agg_grads = list(zip(*agg_grads)) + + return agg_grads + + +def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, + check_inf_nan): + """Calculate the average gradient for a shared variable across all towers. + + Note that this function provides a synchronization point across all towers. + + Args: + grad_and_vars: A list or tuple of (gradient, variable) tuples. Each + (gradient, variable) pair within the outer list represents the gradient + of the variable calculated for a single tower, and the number of pairs + equals the number of towers. + use_mean: if True, mean is taken, else sum of gradients is taken. + check_inf_nan: check grads for nans and infs. + + Returns: + The tuple ([(average_gradient, variable),], has_nan_or_inf) where the + gradient has been averaged across all towers. The variable is chosen from + the first tower. The has_nan_or_inf indicates the grads has nan or inf. + """ + grads = [g for g, _ in grad_and_vars] + grad = math_ops.add_n(grads) + + if use_mean and len(grads) > 1: + grad = array_ops.multiply(grad, 1.0 / len(grads)) + + v = grad_and_vars[0][1] + if check_inf_nan: + has_nan_or_inf = array_ops.logical_not( + array_ops.reduce_all(array_ops.is_finite(grads))) + return (grad, v), has_nan_or_inf + else: + return (grad, v), None + + +def extract_ranges(index_list, range_size_limit=32): + """Extract consecutive ranges and singles from index_list. + + Args: + index_list: List of monotone increasing non-negative integers. + range_size_limit: Largest size range to return. If a larger + consecutive range exists, it will be returned as multiple + ranges. + + Returns: + (ranges, singles) where ranges is a list of [first, last] pairs of + consecutive elements in index_list, and singles is all of the + other elements, in original order. + """ + if not index_list: + return [], [] + first = index_list[0] + last = first + ranges = [] + singles = [] + for i in index_list[1:]: + if i == last + 1 and (last - first) <= range_size_limit: + last = i + else: + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + first = i + last = i + if last > first: + ranges.append([first, last]) + else: + singles.append(first) + return ranges, singles + + +GradPackTuple = pycoll.namedtuple('GradPackTuple', 'indices vars shapes') + + +def pack_range(key, packing, grad_vars, rng): + """Form the concatenation of a specified range of gradient tensors. + + Args: + key: Value under which to store meta-data in packing that will be used + later to restore the grad_var list structure. + packing: Dict holding data describing packed ranges of small tensors. + grad_vars: List of (grad, var) pairs for one tower. + rng: A pair of integers giving the first, last indices of a consecutive + range of tensors to be packed. + + Returns: + A tensor that is the concatenation of all the specified small tensors. + """ + to_pack = grad_vars[rng[0]:rng[1] + 1] + members = [] + variables = [] + restore_shapes = [] + with ops.name_scope('pack'): + for g, v in to_pack: + variables.append(v) + restore_shapes.append(g.shape) + with ops.device(g.device): + members.append(array_ops.reshape(g, [-1])) + packing[key] = GradPackTuple( + indices=range(rng[0], rng[1] + 1), + vars=variables, + shapes=restore_shapes) + with ops.device(members[0].device): + return array_ops.concat(members, 0) + + +def unpack_grad_tuple(gv, gpt): + """Unpack a previously packed collection of gradient tensors. + + Args: + gv: A (grad, var) pair to be unpacked. + gpt: A GradPackTuple describing the packing operation that produced gv. + + Returns: + A list of (grad, var) pairs corresponding to the values that were + originally packed into gv, maybe following subsequent operations like + reduction. + """ + elt_widths = [x.num_elements() for x in gpt.shapes] + with ops.device(gv[0][0].device): + with ops.name_scope('unpack'): + splits = array_ops.split(gv[0], elt_widths) + unpacked_gv = [] + for idx, s in enumerate(splits): + unpacked_gv.append((array_ops.reshape(s, gpt.shapes[idx]), + gpt.vars[idx])) + return unpacked_gv + + +def pack_small_tensors(tower_grads, max_bytes=0, max_group=0): + """Concatenate small gradient tensors together for reduction. + + Args: + tower_grads: List of lists of (gradient, variable) tuples. + max_bytes: Int giving max number of bytes in a tensor that + may be considered small. + max_group: Int giving max number of small tensors that may be + concatenated into one new tensor. + + Returns: + new_tower_grads, packing where new_tower_grads is identical to + tower_grads except that all feasible small_tensors have been removed + from their places and concatenated into larger tensors that are + now in the front of the list for each tower, and packing contains + the data necessary to restore the tower_grads structure. + + Look through the first tower for gradients of the same type (float), + and small size, that are all sequential. For each such group, + replace by a new tensor that is a flattened concatenation. Note + that the corresponding variable will be absent, which doesn't matter + because it isn't used during all-reduce. + + Requires: + Every gv_list in towers must have isomorphic structure including identical + tensor sizes and types. + """ + small_indices = [] + large_indices = [] + for idx, (g, _) in enumerate(tower_grads[0]): + if g.dtype == dtypes.float32 and (4 * g.shape.num_elements()) <= max_bytes: + small_indices.append(idx) + else: + large_indices.append(idx) + small_ranges, small_singles = extract_ranges( + small_indices, range_size_limit=max_group) + large_indices = sorted(large_indices + small_singles) + num_gv = len(tower_grads[0]) + packing = {} + if small_ranges: + new_tower_grads = [] + for dev_idx, gv_list in enumerate(tower_grads): + assert len(gv_list) == num_gv + new_gv_list = [] + for r in small_ranges: + key = '%d:%d' % (dev_idx, len(new_gv_list)) + new_gv_list.append((pack_range(key, packing, gv_list, r), + 'packing_var_placeholder')) + for i in large_indices: + new_gv_list.append(gv_list[i]) + new_tower_grads.append(new_gv_list) + return new_tower_grads, packing + else: + return tower_grads, None + + +def unpack_small_tensors(tower_grads, packing): + """Undo the structure alterations to tower_grads done by pack_small_tensors. + + Args: + tower_grads: List of List of (grad, var) tuples. + packing: A dict generated by pack_small_tensors describing the changes + it made to tower_grads. + + Returns: + new_tower_grads: identical to tower_grads except that concatenations + of small tensors have been split apart and returned to their original + positions, paired with their original variables. + """ + if not packing: + return tower_grads + new_tower_grads = [] + num_devices = len(tower_grads) + num_packed = len(packing.keys()) // num_devices + for dev_idx, gv_list in enumerate(tower_grads): + gv_list = list(gv_list) + new_gv_list = gv_list[num_packed:] + for i in xrange(0, num_packed): + k = '%d:%d' % (dev_idx, i) + gpt = packing[k] + gv = unpack_grad_tuple(gv_list[i], gpt) + for gi, idx in enumerate(gpt.indices): + assert idx == gpt.indices[gi] + new_gv_list.insert(idx, gv[gi]) + new_tower_grads.append(new_gv_list) + return new_tower_grads diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a520ab5aeafb932092ebbbaaf07480cf40403b --- /dev/null +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -0,0 +1,127 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that show that DistributionStrategy works with canned Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.optimizer_v2 import adagrad +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.estimator import run_config +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.summary.writer import writer_cache + + +class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, + parameterized.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def dataset_input_fn(self, x, y, batch_size, shuffle): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + if shuffle: + dataset = dataset.shuffle(batch_size) + dataset = dataset.repeat(10).batch(batch_size) + return dataset + + return input_fn + + @combinations.generate( + combinations.combine( + mode=['graph'], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu + ])) + def test_complete_flow_with_mode(self, distribution): + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + train_input_fn = self.dataset_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size // len(distribution.worker_devices), + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, y=data, batch_size=batch_size, shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, batch_size=batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + estimator = dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=label_dimension, + model_dir=self._model_dir, + # TODO(isaprykin): Work around the colocate_with error. + dnn_optimizer=adagrad.AdagradOptimizer(0.001), + linear_optimizer=adagrad.AdagradOptimizer(0.001), + config=run_config.RunConfig(train_distribute=distribution)) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..cbfd17850212a1c007e2edb9dd3986b3109f040d --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -0,0 +1,30 @@ +# Example TensorFlow models that use DistributionStrategy for training. + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "simple_estimator_example", + srcs = ["simple_estimator_example.py"], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "simple_tfkeras_example", + srcs = [ + "simple_tfkeras_example.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py new file mode 100644 index 0000000000000000000000000000000000000000..00c25c7a2482a559c8b94ff3be86c4961dfb439f --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple example to test the a DistributionStrategy with Estimators. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def build_model_fn_optimizer(): + """Simple model_fn with optimizer.""" + # TODO(anjalisridhar): Move this inside the model_fn once OptimizerV2 is + # done? + optimizer = tf.train.GradientDescentOptimizer(0.2) + + def model_fn(features, labels, mode): # pylint: disable=unused-argument + """model_fn which uses a single unit Dense layer.""" + # You can also use the Flatten layer if you want to test a model without any + # weights. + layer = tf.layers.Dense(1, use_bias=True) + logits = layer(features) + + if mode == tf.estimator.ModeKeys.PREDICT: + predictions = {"logits": logits} + return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + def loss_fn(): + y = tf.reshape(logits, []) - tf.constant(1.) + return y * y + + if mode == tf.estimator.ModeKeys.EVAL: + return tf.estimator.EstimatorSpec(mode, loss=loss_fn()) + + assert mode == tf.estimator.ModeKeys.TRAIN + + global_step = tf.train.get_global_step() + train_op = optimizer.minimize(loss_fn(), global_step=global_step) + return tf.estimator.EstimatorSpec(mode, loss=loss_fn(), train_op=train_op) + + return model_fn + + +def main(_): + distribution = tf.contrib.distribute.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + config = tf.estimator.RunConfig(train_distribute=distribution) + + def input_fn(): + features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + labels = tf.data.Dataset.from_tensors([1.]).repeat(10) + return tf.data.Dataset.zip((features, labels)) + + estimator = tf.estimator.Estimator( + model_fn=build_model_fn_optimizer(), config=config) + estimator.train(input_fn=input_fn, steps=10) + + eval_result = estimator.evaluate(input_fn=input_fn) + print("Eval result: {}".format(eval_result)) + + def predict_input_fn(): + predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10) + return predict_features + + predictions = estimator.predict(input_fn=predict_input_fn) + # TODO(anjalsridhar): This returns a generator object, figure out how to get + # meaningful results here. + print("Prediction results: {}".format(predictions)) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py new file mode 100644 index 0000000000000000000000000000000000000000..b87224251ca3844fc81c6f32a893d2c71664a955 --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An example tf.keras model that is trained using MirroredStrategy.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from sys import argv +import numpy as np +import tensorflow as tf + + +def input_fn(): + x = np.random.random((1024, 10)) + y = np.random.randint(2, size=(1024, 1)) + x = tf.cast(x, tf.float32) + dataset = tf.data.Dataset.from_tensor_slices((x, y)) + dataset = dataset.repeat(10) + dataset = dataset.batch(32) + return dataset + + +def main(args): + if len(args) < 2: + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example./') + return + + print('Using %s to store checkpoints.' % args[1]) + + strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) + optimizer = tf.train.GradientDescentOptimizer(0.2) + + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) + model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + + model.compile(loss='binary_crossentropy', optimizer=optimizer) + model.summary() + tf.keras.backend.set_learning_phase(True) + keras_estimator = tf.keras.estimator.model_to_estimator( + keras_model=model, config=config, model_dir=args[1]) + + keras_estimator.train(input_fn=input_fn, steps=10) + eval_result = keras_estimator.evaluate(input_fn=input_fn) + print('Eval result: {}'.format(eval_result)) + +if __name__ == '__main__': + tf.app.run(argv=argv) diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4219d54cbd43032e9c5a0ee2e76fe025045ba21b --- /dev/null +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -0,0 +1,306 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.ops.losses import losses_impl + + +class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]), + combinations.combine(is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[combinations.adam_optimizer_v1_fn], + mode=["graph"], + use_callable_loss=[False], + is_tpu=[True])) + def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, + is_tpu): + with distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=use_callable_loss) + + # TODO(isaprykin): Eliminate `is_tpu`. Probably add a + # `DistributionStrategy.create_monitor` so that each DistributionStrategy + # could influence its training loop. That method would return an instance + # of Monitor. TPUMonitor would execute tpu.initialize_system() and + # tpu.shutdown_system(). + if is_tpu: + dataset = dataset.batch(2) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + # TODO(isaprykin): Make iterator get_next() return a list of sub- + # batches for each iteration. Pass iterator.get_next() and not iterator + # to call_for_each_tower. + return distribution.group( + distribution.call_for_each_tower( + model_fn, + iterator.get_next() if not is_tpu else iterator, + run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers() + + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph", "eager"]))) + def testOptimizerInsideModelFn(self, distribution, optimizer_fn): + created_variables = [] + trainable_variables = [] + + def appending_creator(next_creator, *args, **kwargs): + v = next_creator(*args, **kwargs) + created_variables.append(v.name) + if "trainable" in kwargs and kwargs["trainable"]: + trainable_variables.append(v.name) + return v + + # Creator scope needs to be set before it's used inside + # `distribution.scope`. + with variable_scope.variable_creator_scope( + appending_creator), distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, + use_bias=True, + use_callable_loss=True, + create_optimizer_inside_model_fn=True) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + def get_expected_variables(optimizer_fn, num_parameter_devices): + variables_map = { + "GradientDescent": ["dense/kernel", "dense/bias"], + "Adam": [ + "dense/kernel", "dense/bias", "beta1_power", "beta2_power", + "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", + "dense/bias/Adam_1" + ] + } + variables = variables_map[optimizer_fn().get_name()] + variables.extend([ + v + "/replica_{}".format(replica) + for v in variables + for replica in range(1, num_parameter_devices) + ]) + return set([v + ":0" for v in variables]) + + self.assertEqual( + get_expected_variables(optimizer_fn, + len(distribution.parameter_devices)), + set(created_variables)) + + @combinations.generate( + combinations.times(combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + momentum=[0.8, 0.9, 0.99], + renorm=[False, True]))) + def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, + renorm): + """Verifies that moving mean updates are reduced across towers.""" + with distribution.scope(): + num_towers = len(distribution.worker_devices) + model_fn, dataset, batchnorm = batchnorm_example( + optimizer_fn, + batch_per_epoch=num_towers, + momentum=momentum, + renorm=renorm) + + # Disable prefetching since that makes the specific input on each device + # to be non deterministic, and this test relies on specific input being + # on each device. + if isinstance(distribution, mirrored_strategy.MirroredStrategy): + distribution._prefetch_on_device = False + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return control_flow_ops.group( + distribution.unwrap( + distribution.call_for_each_tower( + model_fn, + iterator.get_next(), + run_concurrently=batchnorm.built)) + + ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + expected_moving_means = [0.] * 8 + + def averaged_batch_mean(i): + # Each batch has shape [16, 8] where the ith element in jth list is + # (8 * j + i + tower_id * 100). So the batch mean in each tower is + # (60 + i + tower_id * 100). So here comes its batch mean over all + # towers: + return 60. + i + (num_towers - 1.) / 2. * 100. + + for _ in range(10): + run_step() + moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + + # We make sure that the moving_mean is updated as if the sample mean is + # calculated over all towers. + for i, expected_moving_mean in enumerate(expected_moving_means): + expected_moving_means[i] -= (( + expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) + self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + + @combinations.generate( + combinations.times( + combinations.combine( + distribution=[combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn], + loss_reduction=[losses_impl.Reduction.SUM, + losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, + use_callable_loss): + with distribution.scope(): + all_vars = [] + + def model_fn(x, y): + + def loss_fn(): + # Use fixed initialization to make the steps deterministic. + w = variable_scope.get_variable("w", initializer=[[2.]]) + all_vars.append(w) + predict = math_ops.matmul(x, w) + return losses_impl.mean_squared_error( + y, predict, reduction=loss_reduction) + + optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) + labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) + dataset = dataset_ops.Dataset.zip((features, labels)).repeat() + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return distribution.group( + distribution.call_for_each_tower( + model_fn, *iterator.get_next(), run_concurrently=False)) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables_lib.global_variables_initializer()) + + run_step() + + self.assertEqual(distribution.num_towers, len(all_vars)) + v = all_vars[0] + self.assertTrue(all([v is vi for vi in all_vars[1:]])) + weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + # Our model is: + # predict = x * w + # loss = (predict - y)^2 + # dloss/dpredict = 2*(predict - y) + # dloss/dw = 2 * x^T @ (predict - y) + # For our batch size of 2, assuming sum loss reduction: + # x = [2, 7] + # y = [6, 21] + # w_initial = 2 + # predict = [4, 14] + # predict - y = [-2, -7] + # dloss/dw = 2 <[2, 7], [-2, -7]> = - 2(4 + 49) = -106 + # So unreplicated the update to w with lr=0.2 is -0.2 * -106 = 21.2 + # with sum loss reduction, or 10.6 with mean. + if loss_reduction == losses_impl.Reduction.SUM: + # Note that the "distribution.num_towers" factor will go away once + # we split the input across towers, instead of pulling a complete + # batch of input per tower. + self.assertNear(weight, 2 + 21.2 * distribution.num_towers, 0.0001) + else: + # One of the mean loss reductions. + self.assertNear(weight, 2 + 10.6, 0.0001) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0edb3a11df7788991ca14f957494d87593a449 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -0,0 +1,497 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class MirroredStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +import six + +from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.contrib.distribute.python import values +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.eager import context +from tensorflow.python.eager import tape +from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import coordinator +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +def _cpu_device(device): + cpu_device = tf_device.DeviceSpec.from_string(device) + cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) + return cpu_device.to_string() + + +class _RequestedStop(Exception): + pass + + +class MirroredStrategy(distribute_lib.DistributionStrategy): + """Mirrors vars to distribute across multiple devices on a single machine. + + This strategy uses one tower per device and sync replication. + """ + + def __init__(self, + devices=None, + num_gpus=None, + cross_tower_ops=None, + prefetch_on_device=None): + super(MirroredStrategy, self).__init__() + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: + if num_gpus is None: + num_gpus = context.num_gpus() + devices = ["/device:GPU:%d" % d for d in range(num_gpus)] + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + + assert devices, "Must specify at least one device." + assert len(set(devices)) == len(devices), ( + "No duplicates allowed in `devices` argument.") + # TODO(josh11b): Require at least 2 devices? + self._devices = devices + self._canonical_device_set = set( + [device_util.canonicalize(d) for d in devices]) + self._device_index = values.PerDevice( + dict((d, i) for i, d in enumerate(devices))) + self._cross_tower_ops = cross_tower_ops + self._prefetch_on_device = prefetch_on_device + + def _create_variable(self, next_creator, *args, **kwargs): + """Create a mirrored variable. See `DistributionStrategy.scope`.""" + # Figure out what collections this variable should be added to. + # We'll add the MirroredVariable to those collections instead. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + colocate_with = kwargs.pop("colocate_with", None) + devices = self._get_devices_from(colocate_with) + + tower_local = kwargs.pop("tower_local_reduce_method", None) + if tower_local is not None: + kwargs["trainable"] = False + + # TODO(josh11b,apassos): It would be better if variable initialization + # was never recorded on the tape instead of having to do this manually + # here. + with tape.stop_recording(): + index = {} + for i, d in enumerate(devices): + with ops.device(d): + if i > 0: + # Give replicas meaningful distinct names: + var0name = index[devices[0]].name.split(":")[0] + kwargs["name"] = "%s/replica_%d" % (var0name, i) + # Initialize replicas with the same value: + if context.executing_eagerly(): + initial_value = index[devices[0]].value() + else: + initial_value = index[devices[0]].initial_value + kwargs["initial_value"] = array_ops.identity(initial_value) + with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): + v = next_creator(*args, **kwargs) + assert not isinstance(v, values.DistributedVariable) + index[d] = v + + if tower_local is None: + result = values.MirroredVariable(index, index[devices[0]]) + else: + result = values.TowerLocalVariable( + index, index[devices[0]], tower_local) + + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the member variables + # to the TRAINABLE_VARIABLES collection, so we manually remove + # them and replace with the MirroredVariable. We can't set + # "trainable" to False for next_creator() since that causes functions + # like implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + for v in index.values(): + l.remove(v) + g.add_to_collections(collections, result) + return result + + def distribute_dataset(self, dataset): + per_device_dataset = values.PerDeviceDataset( + dataset, self._devices, self._prefetch_on_device) + return per_device_dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + # TODO(josh11b): In eager mode, use one thread per device, or async mode. + return self._get_cross_tower_ops().broadcast(tensor, destinations or + self._devices) + + def _call_for_each_tower(self, fn, *args, **kwargs): + """Run `fn` in separate threads, once per tower/worker device. + + Args: + fn: function to run (will be run once per device, each in its own thread). + *args: positional arguments for `fn` + **kwargs: keyword arguments for `fn`. + `"run_concurrently"`: Boolean indicating whether executions of `fn` + can be run concurrently (under eager execution only), defaults to + `True`. + + Returns: + Merged return value of `fn` across all towers. + + Raises: + RuntimeError: If fn() calls get_tower_context().merge_call() a different + number of times for when called for different devices. + """ + run_concurrently = kwargs.pop("run_concurrently", True) + if not context.executing_eagerly(): + # Lots of TF library code isn't thread-safe in graph mode, and + # there is little to be gained by turning on multithreading when + # constructing a graph. + run_concurrently = False + # Needed for per-thread device, etc. contexts in graph mode. + ops.get_default_graph().switch_to_thread_local() + elif run_concurrently is None: + run_concurrently = True + + coord = coordinator.Coordinator( + clean_stop_exception_types=(_RequestedStop,)) + + shared_variable_store = {} + + # TODO(isaprykin): Create these threads once instead of during every run() + # call. + threads = [] + for index, d in enumerate(self._devices): + variable_creator_fn = shared_variable_creator.make_fn( + shared_variable_store, index) + t = MirroredStrategy._MirroredTowerThread( + self, coord, d, variable_creator_fn, fn, + *values.select_device(d, args), **values.select_device(d, kwargs)) + threads.append(t) + + for t in threads: + t.start() + + # When `fn` starts `should_run` event is set on _MirroredTowerThread + # (`MTT`) threads. The execution waits until + # `MTT.has_paused` is set, which indicates that either `fn` is + # complete or a `get_tower_context().merge_call()` is called. If `fn` is + # complete, then `MTT.done` is set to True. Otherwise, arguments + # of `get_tower_context().merge_call` from all paused threads are grouped + # and the `merge_fn` is performed. Results of the + # `get_tower_context().merge_call` are then set to `MTT.merge_result`. + # Each such `get_tower_context().merge_call` call returns the + # `MTT.merge_result` for that thread when `MTT.should_run` event + # is reset again. Execution of `fn` resumes. + + try: + with coord.stop_on_exception(): + all_done = False + while not all_done and not coord.should_stop(): + done = [] + if run_concurrently: + for t in threads: + t.should_run.set() + for t in threads: + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + else: + for t in threads: + t.should_run.set() + t.has_paused.wait() + t.has_paused.clear() + if coord.should_stop(): + return None + done.append(t.done) + if coord.should_stop(): + return None + all_done = all(done) + if not all_done: + if any(done): + raise RuntimeError("Some towers made a different number of " + "tower_context().merge_call() calls.") + # get_tower_context().merge_call() case + merge_args = values.regroup( + {t.device: t.merge_args for t in threads}) + merge_kwargs = values.regroup( + {t.device: t.merge_kwargs for t in threads}) + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) + for t in threads: + t.merge_result = values.select_device(t.device, merge_result) + finally: + for t in threads: + t.should_run.set() + coord.join(threads) + + return values.regroup({t.device: t.main_result for t in threads}) + + def map(self, map_over, fn, *args, **kwargs): + # TODO(josh11b): In eager mode, use one thread per device. + index = {} + i = 0 + for m in map_over: + d = self._devices[i % len(self._devices)] + with ops.device(d): + l = index.get(d, []) + l.append(fn(m, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs))) + index[d] = l + # TODO(josh11b): Need a values.regroup equivalent that handles MapOutput + # in addition to PerDevice data. + return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + + def configure(self, session_config=None): + if self._cross_tower_ops is None: + self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( + self._devices, session_config=session_config) + + def _get_cross_tower_ops(self): + if self._cross_tower_ops is None: + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) + return self._cross_tower_ops + + def _reduce(self, method_string, value, destinations): + if len(self._devices) == 1 and not isinstance(value, values.PerDevice): + value = values.PerDevice({self._devices[0]: value}) + assert isinstance(value, values.PerDevice) + + return self._get_cross_tower_ops().reduce( + method_string, value, destinations=destinations) + + def _batch_reduce(self, method_string, value_destination_pairs): + return self._get_cross_tower_ops().batch_reduce(method_string, + value_destination_pairs) + + def _update(self, var, fn, *args, **kwargs): + # TODO(josh11b): Also support TowerLocalVariables here? If so, args and + # kwargs don't need to be mirrored. + assert isinstance(var, values.MirroredVariable) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d, v in var._index.items(): # pylint: disable=protected-access + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(v, + *values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + assert isinstance(colocate_with, list) + # TODO(josh11b): In eager mode, use one thread per device. + updates = {} + for d in colocate_with: + name = "update_%d" % self._device_index.get(d) + with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + updates[d] = fn(*values.select_device_mirrored(d, args), + **values.select_device_mirrored(d, kwargs)) + return values.regroup(updates, values.Mirrored) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + assert isinstance(destination, six.string_types) + if isinstance(val, values.TowerLocalVariable): + val = self.reduce(val.reduce_method, val, destinations=destination) + with ops.device(destination): + return fn(self.unwrap(val)[0]) + + assert isinstance(val, values.Mirrored), ( + "val = %s (type %s)" % (val, val.__class__.__name__)) + if val.on_device(destination): + with ops.device(destination): + # Use an identity here to make sure we are returning a tensor + # instead of e.g. a variable object. + return array_ops.identity(fn(val.get(destination))) + device = None + for d in self._devices: + if val.on_device(d): + device = d + break + assert device is not None, ( + "Could not find destination %s in list of devices %s." % + (destination, val.devices)) + with ops.device(device): + v = fn(val.get(device)) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, val): + if isinstance(val, values.DistributedValues): + # Return in a deterministic order. + if set(val.devices) == self._canonical_device_set: + return [val.get(device=d) for d in self._devices] + return [val.get(device=d) for d in sorted(val.devices)] + return [val] + + @property + def is_single_tower(self): + return len(self._devices) == 1 + + @property + def num_towers(self): + return len(self._devices) + + def _worker_device_index(self): + return self._device_index + + @property + def worker_devices(self): + # Make a copy to prevent users from accidentally mutating our copy. + return list(self._devices) + + @property + def parameter_devices(self): + return list(self._devices) + + def non_slot_devices(self, var_list): + del var_list + return list(self._devices) + + def _get_devices_from(self, colocate_with=None): + if colocate_with is None: + return self._devices + elif isinstance(colocate_with, values.DistributedValues): + # pylint: disable=protected-access + return list(colocate_with._index.keys()) + elif isinstance(colocate_with, six.string_types): + return [colocate_with] + else: + return colocate_with + + class _MirroredTowerThread(threading.Thread): + """A thread that runs() a function on a device.""" + + def __init__(self, dist, coord, device, variable_creator_fn, fn, *args, + **kwargs): + super(MirroredStrategy._MirroredTowerThread, self).__init__() # pylint: disable=protected-access + self.coord = coord + self.distribution = dist + self.device = device + self.tower_id = dist.worker_devices.index(device) + self.variable_creator_fn = variable_creator_fn + # State needed to run and return the results of `fn`. + self.main_fn = fn + self.main_args = args + self.main_kwargs = kwargs + self.main_result = None + self.done = False + # State needed to run the next merge_call() (if any) requested via + # TowerContext. + self.merge_fn = None + self.merge_args = None + self.merge_kwargs = None + self.merge_result = None + # We use a thread.Event for the main thread to signal when this + # thread should start running (`should_run`), and another for + # this thread to transfer control back to the main thread + # (`has_paused`, either when it gets to a + # `get_tower_context().merge_call` or when `fn` returns). In + # either case the event starts cleared, is signaled by calling + # set(). The receiving thread waits for the signal by calling + # wait() and then immediately clearing the event using clear(). + self.should_run = threading.Event() + self.has_paused = threading.Event() + # These fields have to do with inheriting various contexts from the + # parent thread: + # pylint: disable=protected-access + self.context_mode = context.context()._eager_context.mode + if not context.context()._context_handle: + context.context()._initialize_handle_and_devices() + self.context_device_policy = ( + pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy( + context.context()._context_handle)) + self.graph = ops.get_default_graph() + self._variable_creator_stack = self.graph._variable_creator_stack[:] + self._captured_var_scope = variable_scope.get_variable_scope() + # Adding a "/" at end lets us re-enter this scope later. + self._captured_name_scope = self.graph.get_name_scope() + if self._captured_name_scope: + self._captured_name_scope += "/" + if self.tower_id > 0: + if not self._captured_name_scope: + self._captured_name_scope = "" + self._captured_name_scope += "tower_%d/" % self.tower_id + + def run(self): + # pylint: disable=protected-access + self.graph._variable_creator_stack = self._variable_creator_stack + self.should_run.wait() + self.should_run.clear() + try: + if self.coord.should_stop(): + return + with self.coord.stop_on_exception(), \ + context.context()._mode(self.context_mode), \ + context.context().device_policy(self.context_device_policy), \ + self.graph.as_default(), \ + MirroredTowerContext(self.distribution, self.tower_id), \ + ops.device(self.device), \ + ops.name_scope(self._captured_name_scope), \ + variable_scope.variable_scope( + self._captured_var_scope, reuse=self.tower_id > 0), \ + variable_scope.variable_creator_scope(self.variable_creator_fn): + self.main_result = self.main_fn(*self.main_args, **self.main_kwargs) + self.done = True + finally: + self.has_paused.set() + + +class MirroredTowerContext(distribute_lib.TowerContext): + """TowerContext used in MirroredStrategy.call_for_each_tower(). + + Opened in `_MirroredTowerThread`, to allow the user to invoke + `MirroredStrategy`'s specific implementation of `merge_call()`, + which works by delegating the function and its arguments to + the main thread (the one that invoked + `MirroredStrategy.call_for_each_tower()`). + """ + + def _merge_call(self, fn, *args, **kwargs): + """Delegate to the main thread to actually perform merge_call().""" + t = threading.current_thread() # a _MirroredTowerThread + t.merge_fn = fn + t.merge_args = args + t.merge_kwargs = kwargs + t.has_paused.set() + t.should_run.wait() + t.should_run.clear() + if t.coord.should_stop(): + raise _RequestedStop() + return t.merge_result + + @property + def device(self): + distribute_lib.require_tower_context(self) + return self._distribution_strategy.worker_devices[self._tower_id] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9f06da8e2ed185c2c32f79a5a4f5407165fb1d --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -0,0 +1,435 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multi-GPU tests for MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib + +GPU_TEST = "test_gpu" in sys.argv[0] + + +class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + if GPU_TEST: + self.assertGreater(context.num_gpus(), 0) + if context.num_gpus() > 1: + devices = ["/device:GPU:0", "/device:GPU:1"] + print(self.id().split(".")[-1], "devices:", ", ".join(devices)) + return mirrored_strategy.MirroredStrategy(devices) + + def testMinimizeLossEager(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + soft_placement = not GPU_TEST + print("testMinimizeLossGraph soft_placement:", soft_placement) + self._test_minimize_loss_graph( + self._get_distribution_strategy(), soft_placement=soft_placement) + + def testMapReduce(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_tower_id(self._get_distribution_strategy()) + + def testNumTowers(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self.assertEqual(2, self._get_distribution_strategy().num_towers) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testRunRegroupError(self): + + def run_fn(device_id): + # Generates a list with different lengths on different devices. + # Will fail in _regroup() (if more than one device). + return list(range(device_id)) + + dist = self._get_distribution_strategy() + with dist.scope(), self.assertRaises(AssertionError): + dist.call_for_each_tower(run_fn, dist.worker_device_index) + + @test_util.run_in_graph_and_eager_modes() + def testReduceToCpu(self): + if not GPU_TEST: + self.skipTest("Not GPU test") + + def run_fn(device_id): + return device_id + + dist = self._get_distribution_strategy() + with dist.scope(): + result = dist.call_for_each_tower(run_fn, dist.worker_device_index) + reduced = dist.reduce("sum", result, destinations="/device:CPU:0") + unwrapped = dist.unwrap(reduced) + self.assertEqual(1, len(unwrapped)) + expected = sum(range(len(dist.worker_devices))) + self.assertEqual(expected, self.evaluate(unwrapped[0])) + + +@test_util.with_c_api +class MirroredStrategyVariableCreationTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _skip_eager_if_gpus_less_than(self, num_gpus): + if context.num_gpus() < num_gpus and context.executing_eagerly(): + self.skipTest("Enough GPUs not available for this test in eager mode.") + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSingleVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + # This variable should be created only once across the threads because of + # special variable_creator functions used by `dist.call_for_each_tower`. + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnnamedVariable(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v = variable_scope.variable(1.0) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # Default name of "Variable" will be used. + self.assertEquals("Variable:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariables(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + for i in range(5): + vs.append(variable_scope.variable(1.0, name="foo" + str(i))) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for i, v in enumerate(result): + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals("foo" + str(i) + ":0", v.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleVariablesWithSameCanonicalName(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + vs = [] + vs.append(variable_scope.variable(1.0, name="foo/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar")) + vs.append(variable_scope.variable(1.0, name="foo_1/bar_1")) + vs.append(variable_scope.variable(1.0, name="foo/bar_1")) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return vs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + for v in result: + self.assertIsInstance(v, values.MirroredVariable) + self.assertEquals(4, len(result)) + self.assertEquals("foo/bar:0", result[0].name) + self.assertEquals("foo_1/bar:0", result[1].name) + self.assertEquals("foo_1/bar_1:0", result[2].name) + self.assertEquals("foo/bar_1:0", result[3].name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableWithSameCanonicalNameAcrossThreads(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(device_id): + v = variable_scope.variable(1.0, name="foo_" + str(device_id)) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + # The resulting mirrored variable will use the name from the first device. + self.assertEquals("foo_0:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithLayers(self): + self._skip_eager_if_gpus_less_than(1) + def model_fn(features): + with variable_scope.variable_scope("common"): + layer1 = core.Dense(1) + layer1(features) + layer2 = core.Dense(1) + layer2(features) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + layer3 = core.Dense(1) + layer3(features) + return [(layer1.kernel, layer1.bias), + (layer2.kernel, layer2.bias), + (layer3.kernel, layer3.bias)] + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + features = dist.distribute_dataset(features).get_next() + + with dist.scope(): + result = dist.call_for_each_tower( + model_fn, features, run_concurrently=False) + suffixes = ["", "_1", "_2"] + for (kernel, bias), suffix in zip(result, suffixes): + self.assertIsInstance(kernel, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/kernel:0", kernel.name) + self.assertIsInstance(bias, values.MirroredVariable) + self.assertEquals("common/dense" + suffix + "/bias:0", bias.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testWithGetVariableAndVariableScope(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(): + v0 = variable_scope.get_variable("var-thread0", [1]) + with variable_scope.variable_scope("common"): + v1 = variable_scope.get_variable("var-thread1", [1]) + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + v2 = variable_scope.get_variable("var-thread2", [1]) + + return v0, v1, v2 + + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + with dist.scope(): + with variable_scope.variable_scope("main"): + v = variable_scope.get_variable("var-main0", [1]) + self.assertEquals("main/var-main0:0", v.name) + + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(3, len(result)) + v0, v1, v2 = result + self.assertIsInstance(v0, values.MirroredVariable) + self.assertEquals("main/var-thread0:0", v0.name) + self.assertIsInstance(v1, values.MirroredVariable) + self.assertEquals("main/common/var-thread1:0", v1.name) + self.assertIsInstance(v2, values.MirroredVariable) + self.assertEquals("main/common/var-thread2:0", v2.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testThreeDevices(self): + self._skip_eager_if_gpus_less_than(2) + + def model_fn(): + v = variable_scope.variable(1.0, name="foo") + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1", "/device:CPU:0"]) + + with dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertIsInstance(result, values.MirroredVariable) + self.assertEquals("foo:0", result.name) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testNonMatchingVariableCreation(self): + self._skip_eager_if_gpus_less_than(1) + + def model_fn(name): + v = variable_scope.variable(1.0, name=name) + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + names = values.DistributedValues({ + "/device:CPU:0": "foo", + "/device:GPU:0": "bar" + }) + with self.assertRaises(RuntimeError): + _ = dist.call_for_each_tower(model_fn, names, run_concurrently=False) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariable(self): + self._skip_eager_if_gpus_less_than(1) + + all_v_sum = {} + all_v_mean = {} + + def model_fn(device_id): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope("sum"): + v_sum = variable_scope.variable(1.0) + with tower_context.tower_local_var_scope("mean"): + v_mean = variable_scope.variable(4.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + self.assertTrue(isinstance(v_mean, values.TowerLocalVariable)) + updates = [v_sum.assign_add(2.0 + device_id), + v_mean.assign(6.0 * device_id)] + all_v_sum[device_id] = v_sum + all_v_mean[device_id] = v_mean + return updates, v_sum, v_mean + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with dist.scope(): + # Create "sum" and "mean" versions of TowerLocalVariables. + ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False) + # Should see the same wrapping instance in all towers. + self.assertIs(all_v_sum[0], ret_v_sum) + self.assertIs(all_v_mean[0], ret_v_mean) + for i in range(1, dist.num_towers): + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Apply updates + self.evaluate(variables.global_variables_initializer()) + self.evaluate([y for x in ret_ops for y in dist.unwrap(x)]) + expected_sum = 0.0 + expected_mean = 0.0 + for i, d in enumerate(dist.worker_devices): + # Test access within a device scope, should see different values. + with ops.device(d): + v_sum_value = self.evaluate(ret_v_sum.read_value()) + v_mean_value = self.evaluate(ret_v_mean.read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected + + # fetch() should return the value you get by applying the + # reduction across all towers. + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + expected_mean /= len(dist.worker_devices) + self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + + # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not + # testing this in eager mode. + + def testNameScope(self): + def model_fn(): + with ops.name_scope("foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(1.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("main/foo/" + name + ":0", v0.name) + self.assertEquals("main/tower_1/foo/" + name + ":0", v1.name) + + def testWithDefaultName(self): + def model_fn(): + with ops.name_scope(None, "foo"): + a = constant_op.constant(1.0, name="a") + distribute_lib.get_tower_context().merge_call(lambda _: _) + b = constant_op.constant(2.0, name="b") + return a, b + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + self.assertEquals(2, len(result)) + for v, name in zip(result, ["a", "b"]): + self.assertIsInstance(v, values.DistributedValues) + v0, v1 = dist.unwrap(v) + self.assertEquals("foo/" + name + ":0", v0.name) + self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ef0ecc77a8e8432dfa4eb6da7c324b371dab70 --- /dev/null +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -0,0 +1,91 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class MirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import distribute as distribute_lib + + +@test_util.with_c_api +class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return mirrored_strategy.MirroredStrategy(["/device:CPU:0"]) + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +@test_util.with_c_api +class VariableCreatorStackTest(test.TestCase): + + def testCreatorStacksAreThreadLocal(self): + devices = ["/device:CPU:0", "/device:GPU:0"] + dist = mirrored_strategy.MirroredStrategy(devices) + + def model_fn(device_id): + assert isinstance(device_id, int) + def thread_creator_fn(next_creator, *args, **kwargs): + return next_creator(*args, **kwargs) + ":thread_" + str(device_id) + + with variable_scope.variable_creator_scope(thread_creator_fn): + # Create a variable in this scope. + v = variable_scope.variable(1.0) + + # This will pause the current thread, and execute the other thread. + distribute_lib.get_tower_context().merge_call(lambda _: _) + return v + + def main_thread_creator(next_creator, *args, **kwargs): + # We are not using the underlying next_creator for test purposes. + del next_creator, args, kwargs + return "main_thread" + + with context.graph_mode(), \ + dist.scope(), \ + variable_scope.variable_creator_scope(main_thread_creator): + result = dist.call_for_each_tower(model_fn, dist.worker_device_index) + result = dist.unwrap(result) + expected = ["main_thread:thread_0", "main_thread:thread_1"] + self.assertEquals(expected, result) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..7644acedc99361d7287a91832d76bc68cbc6ac0a --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Monitor is responsible for training, checkpointing and recovery.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.ops import variables + + +class Monitor(object): + """Executes training steps, recovers and checkpoints. + + Note that this class is particularly preliminary, experimental, and + expected to change. + """ + # TODO(isaprykin): Support step functions that need multiple session calls. + # TODO(isaprykin): Support extra arguments to the step function. + # TODO(isaprykin): Support recovery, checkpointing and summaries. + + def __init__(self, step_callable, session=None): + """Initialize the Monitor with components for executing training steps. + + Args: + step_callable: a training `Step` that's capable of signaling when done. + session: a `Session` instance that's needed for graph mode. + + Raises: + ValueError: if `session` was provided for eager mode or not provided for + graph mode. + """ + if context.executing_eagerly(): + if session is not None: + raise ValueError("Should not provide a `session` in Eager mode.") + self._run_step = step_callable + else: + if session is None: + raise ValueError("Should provide a `session` in Graph mode.") + self._run_step = session.make_callable(step_callable()) + session.run(variables.global_variables_initializer()) + + def run_steps(self, num_steps=None): + step = 0 + while num_steps is None or step < num_steps: + try: + self._run_step() + step += 1 + except errors.OutOfRangeError: + break diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8277e1e7919e86ef616b31d0986589dcc9c49bbd --- /dev/null +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Monitor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import monitor as monitor_lib +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.training import gradient_descent + + +class MonitorTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example(optimizer_fn, distribution) + + if context.executing_eagerly(): + monitor = monitor_lib.Monitor(single_loss_step, None) + else: + with self.test_session() as sess: + monitor = monitor_lib.Monitor(single_loss_step, sess) + + monitor.run_steps(1) + + self.assertEqual(1, len(layer.trainable_variables)) + mirrored_weight_variable = layer.trainable_variables[0] + start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = abs(numpy.array(start_error) - 1) + + monitor.run_steps(9) + end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = abs(numpy.array(end_error) - 1) + self.assertGreaterEqual(start_error, end_error) + + def testPassingASessionInEager(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with self.test_session() as sess: + with self.assertRaisesRegexp(ValueError, "Should not provide"): + _ = monitor_lib.Monitor(step_function, sess) + + def testNotPassingASessionInGraph(self): + distribution = one_device_strategy.OneDeviceStrategy( + "/device:CPU:0") + step_function, _ = single_loss_example( + lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) + + with context.graph_mode(), ops.Graph().as_default(): + with self.assertRaisesRegexp(ValueError, "Should provide"): + _ = monitor_lib.Monitor(step_function, session=None) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..39c49442b9c3245cfd0b67a51be68773a6fd3ff4 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -0,0 +1,148 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class OneDeviceStrategy implementing DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.eager.python import datasets +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import distribute as distribute_lib + + +# TODO(josh11b): Replace asserts in this file with if ...: raise ... + + +class OneDeviceStrategy(distribute_lib.DistributionStrategy): + """A distribution strategy for running on a single device.""" + # TODO(josh11b): Do we wrap values in types to generate errors if you are + # doing something that won't work with other DistributionStrategy + # implementations? + + def __init__(self, device): + super(OneDeviceStrategy, self).__init__() + self._device = device + + def _create_variable(self, next_creator, *args, **kwargs): + # No need to distinguish tower-local variables when not mirroring, + # we just enforce that they are not trainable. + if kwargs.pop("tower_local_reduce_method", None) is not None: + kwargs["trainable"] = False + + colocate_with = kwargs.pop("colocate_with", None) + if colocate_with is None: + with ops.device(self._device): + return next_creator(*args, **kwargs) + if isinstance(colocate_with, six.string_types): + with ops.device(colocate_with): + return next_creator(*args, **kwargs) + if (isinstance(colocate_with, list) and len(colocate_with) == 1 and + isinstance(colocate_with[0], six.string_types)): + with ops.device(colocate_with[0]): + return next_creator(*args, **kwargs) + with ops.colocate_with(colocate_with): + return next_creator(*args, **kwargs) + + def distribute_dataset(self, dataset): + if context.executing_eagerly(): + return datasets.Iterator(dataset) + else: + return dataset.make_one_shot_iterator() + + def _broadcast(self, tensor, destinations): + return tensor + + def _call_for_each_tower(self, fn, *args, **kwargs): + # We don't run `fn` in multiple threads in OneDeviceStrategy. + kwargs.pop("run_concurrently", None) + with ops.device(self._device), _OneDeviceTowerContext(self): + return fn(*args, **kwargs) + + def map(self, map_over, fn, *args, **kwargs): + with ops.device(self._device): + return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) + + def _reduce(self, method_string, value, destinations): + if not isinstance(value, values.MapOutput): + return value + l = value.get() + assert l + with ops.device(self._device): + if method_string == "sum": + return math_ops.add_n(l) + elif method_string == "mean": + return math_ops.add_n(l) / len(l) + else: + assert False + + def _update(self, var, fn, *args, **kwargs): + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(var, *args, **kwargs) + + def _update_non_slot(self, colocate_with, fn, *args, **kwargs): + del colocate_with + with ops.device(self._device), distribute_lib.UpdateContext(self._device): + return fn(*args, **kwargs) + + def _fetch(self, val, destination, fn): + """Return a copy of `val` or `fn(val)` on `destination`.""" + with ops.device(self._device): + v = fn(val) + with ops.device(destination): + return array_ops.identity(v) + + def _unwrap(self, value): + return [value] + + @property + def is_single_tower(self): + return True + + @property + def num_towers(self): + return 1 + + @property + def worker_devices(self): + return [self._device] + + @property + def parameter_devices(self): + return [self._device] + + def non_slot_devices(self, var_list): + del var_list + return [self._device] + + def _worker_device_index(self): + return 0 + + +class _OneDeviceTowerContext(distribute_lib.TowerContext): + + def __init__(self, distribution_strategy): + distribute_lib.TowerContext.__init__( + self, distribution_strategy, tower_id=0) + + @property + def device(self): + return self._distribution_strategy.worker_devices[0] diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7101ed0756f44b846f10ddc6d429afe005a2f196 --- /dev/null +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -0,0 +1,54 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class OneDeviceStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +@test_util.with_c_api +class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return one_device_strategy.OneDeviceStrategy("/device:CPU:0") + + def testMinimizeLossEager(self): + self._test_minimize_loss_eager(self._get_distribution_strategy()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + def testMapReduce(self): + self._test_map_reduce(self._get_distribution_strategy()) + + def testDeviceIndex(self): + self._test_device_index(self._get_distribution_strategy()) + + def testTowerId(self): + self._test_tower_id(self._get_distribution_strategy()) + + @test_util.run_in_graph_and_eager_modes() + def testCallAndMergeExceptions(self): + self._test_call_and_merge_exceptions(self._get_distribution_strategy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0912b625f44342d22acc0ce9bb52a6b632c75a0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for running legacy optimizer code with DistributionStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables + + +class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v2_optimizers(), + combinations.combine(mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + def testTrainNetwork(self, distribution, optimizer_fn, + use_callable_loss=True): + with distribution.scope(): + model_fn, dataset, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) + + iterator = distribution.distribute_dataset(dataset) + + def run_step(): + return control_flow_ops.group(distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), run_concurrently=layer.built))) + + if not context.executing_eagerly(): + with self.test_session() as sess: + run_step = sess.make_callable(run_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcbb8568f92ebabbeeedb45ee677e4ee23d77dc --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -0,0 +1,168 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extension of prefetching_ops to support more than one device.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest as data_nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.util import nest + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for @{tf.data.Iterator} that prefetches to another device.""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._devices = devices + input_iterator = input_dataset.make_one_shot_iterator() + input_iterator_handle = input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, input_iterator.output_types, input_iterator.output_shapes, + input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + target_device = gen_dataset_ops.iterator_get_device( + input_iterator._iterator_resource) + self._buffering_resources = [] + for device in nest.flatten(self._devices): + with ops.device(device): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_prefetch_fn, + target_device=target_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size) + self._buffering_resources.append(buffer_resource_handle) + + def get_next(self, name=None): + """See @{tf.data.Iterator.get_next}.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_result = [] + # TODO(priyag): This will fail if the input size (typically number of + # batches) is not divisible by number of devices. + # How do we handle that more gracefully / let the user know? + for buffer_resource in self._buffering_resources: + flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + buffer_resource, + output_types=data_nest.flatten(sparse.as_dense_types( + self.output_types, self.output_classes)), name=name) + + ret = sparse.deserialize_sparse_tensors( + data_nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + flat_result.append(ret) + + return nest.pack_sequence_as(self._devices, flat_result) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.Dataset): + """A `Dataset` whose iterator prefetches elements to other device(s).""" + + def __init__(self, input_dataset, devices, buffer_size): + self._input_dataset = input_dataset + self._devices = devices + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator(self._input_dataset, self._devices, + self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + raise NotImplementedError("`prefetch_to_devices()` is not currently " + "compatible with initializable iterators. Use " + "`make_one_shot_iterator()` instead.") + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_devices()` must be the last " + "transformation in a dataset pipeline.") + + # TODO(priyag): Fix the output types, shapes and classes to match the result + # of get_next (which has the additional nesting layer of devices now). + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_devices(devices, buffer_size=None): + """A transformation that prefetches dataset values to the given `devices`. + + NOTE: Although the transformation creates a @{tf.data.Dataset}, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + devices: A nested structure of devices on which to prefetch the data. It can + be a single device name, or a tuple or list of device names. + buffer_size: (Optional.) The number of elements to buffer on each device. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, devices, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed16f4607881f2864479c04b4c25e95d9fa1850 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class PrefetchingOpsV2Test(test.TestCase): + + def testPrefetchToOneDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesInAList(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + output = [] + with self.test_session() as sess: + for _ in range(5): + result = sess.run(next_element) + self.assertEqual(2, len(result)) + output.extend(result) + self.assertEquals(set(range(10)), set(output)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator.py b/tensorflow/contrib/distribute/python/shared_variable_creator.py new file mode 100644 index 0000000000000000000000000000000000000000..a7083e279f20803b227dcd52f6420ae832aa2df4 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to re-use variables created on first device on subsequent devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +_VARIABLE_UNIQUIFYING_REGEX = re.compile(r"_\d/") +_VARIABLE_UNIQUIFYING_REGEX_AT_END = re.compile(r"_\d$") + + +def _canonicalize_variable_name(name): + # If no name is specified, uses default name "Variable". + if name is None: + return "Variable" + # Replace all instances of "_/" with "/" + name = _VARIABLE_UNIQUIFYING_REGEX.sub("/", name) + # Replace any instances of "_" at the end of the string with "" + name = _VARIABLE_UNIQUIFYING_REGEX_AT_END.sub("", name) + return name + + +def make_fn(shared_variable_store, device_id): + """Construct the variable creator function for device `device_id`. + + Constructs custom variable creator functions for the given device. + On first device (device_id == 0), it creates the variable using the + `next_creator`, and stores it in the provided `shared_variable_store`. + On all other devices (device_id > 0), it tries to re-use the variable + already created with the same name. If no such variable exists, it throws an + error. + Additionally, we de-uniquify variable names before checking for matches. This + helps re-use variables which are intended to be the same but have different + names due to variable uniquification happening upstream. Since this might + mean we may have multiple variables with the same canonical name, we store + them in a list per canonical name and return them in the same order as well. + + Args: + shared_variable_store: A dictionary that we will use to store variables + created on the first device, and re-used by creators for other devices. + device_id: Integer index of the device whose creator should be + constructed. + + Returns: + An appropriate creator function based on device_id. + + """ + variable_scope_access_index = {} + assert isinstance(device_id, int) + + def create_new_variable(next_creator, *args, **kwargs): + """Create the variable using `next_creator` and store it.""" + canonical_name = _canonicalize_variable_name(kwargs.get("name")) + v = next_creator(*args, **kwargs) + + if canonical_name not in shared_variable_store: + shared_variable_store[canonical_name] = [] + shared_variable_store[canonical_name].append(v) + return v + + def reuse_variable(next_creator, *args, **kwargs): + """Re-use existing variable from store with same name (in order).""" + del next_creator, args + name = kwargs.get("name") + canonical_name = _canonicalize_variable_name(name) + + try: + variable_index = variable_scope_access_index.get(canonical_name, 0) + v = shared_variable_store[canonical_name][variable_index] + # TODO(priyag): Make this variable re-use more robust by adding checks + # that the requested shape and dtype match the existing variable. + variable_scope_access_index[canonical_name] = variable_index + 1 + return v + except (KeyError, IndexError): + raise RuntimeError( + "Tried to create variable {} with mismatching name on device {}". + format(name, device_id)) + + if device_id == 0: + return create_new_variable + else: + return reuse_variable diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..713494d603b855be2863af9f24ab98d4cf048042 --- /dev/null +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -0,0 +1,75 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SharedVariableCreator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import shared_variable_creator +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variable_scope + + +class CanonicalizeVariableNameTest(test.TestCase): + + def _canonicalize(self, name): + return shared_variable_creator._canonicalize_variable_name(name) + + def testNoName(self): + self.assertEquals("Variable", self._canonicalize(None)) + + def testPatternInMiddle(self): + self.assertEquals("foo/bar/baz", self._canonicalize("foo_1/bar_1/baz")) + + def testPatternAtEnd(self): + self.assertEquals("foo", self._canonicalize("foo_1")) + + def testWrongPatterns(self): + self.assertEquals("foo_1:0", self._canonicalize("foo_1:0")) + self.assertEquals("foo1", self._canonicalize("foo1")) + self.assertEquals("foo_a", self._canonicalize("foo_a")) + + +@test_util.with_c_api +class SharedVariableCreatorTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSharedVariable(self): + + shared_variable_store = {} + num_devices = 3 + creator_fns = [] + for i in range(num_devices): + creator_fn = shared_variable_creator.make_fn(shared_variable_store, i) + creator_fns.append(creator_fn) + + with variable_scope.variable_creator_scope(creator_fns[0]): + v0 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[1]): + v1 = variable_scope.variable(1.0, name="foo") + + with variable_scope.variable_creator_scope(creator_fns[2]): + v2 = variable_scope.variable(1.0, name="foo") + + # v1 and v2 should be same as v0 + self.assertIs(v1, v0) + self.assertIs(v2, v0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py new file mode 100644 index 0000000000000000000000000000000000000000..cef5fd2f8943d348a0721cd72032bf6cb2199ad9 --- /dev/null +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A simple network to use in tests and examples.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import step_fn +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.layers import core +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def single_loss_example(optimizer_fn, distribution, use_bias=False): + """Build a very simple network to use in tests and examples.""" + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + optimizer = optimizer_fn() + layer = core.Dense(1, use_bias=use_bias) + + def loss_fn(x): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer, + distribution) + + # Layer is returned for inspecting the kernels in tests. + return single_loss_step, layer + + +def minimize_loss_example(optimizer_fn, + use_bias=False, + use_callable_loss=True, + create_optimizer_inside_model_fn=False): + """Example of non-distribution-aware legacy code.""" + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # An Optimizer instance is created either outside or inside model_fn. + outer_optimizer = None + if not create_optimizer_inside_model_fn: + outer_optimizer = optimizer_fn() + + layer = core.Dense(1, use_bias=use_bias) + + def model_fn(x): + """A very simple model written by the user.""" + + def loss_fn(): + y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + return y * y + + optimizer = outer_optimizer or optimizer_fn() + + if use_callable_loss: + return optimizer.minimize(loss_fn) + else: + return optimizer.minimize(loss_fn()) + + return model_fn, dataset, layer + + +def batchnorm_example(optimizer_fn, + batch_per_epoch=1, + momentum=0.9, + renorm=False): + """Example of non-distribution-aware legacy code with batch normalization.""" + # input shape is [16, 8], input values are increasing in both dimensions. + dataset = dataset_ops.Dataset.from_tensor_slices( + [[[float(x * 8 + y + z * 100) + for y in range(8)] + for x in range(16)] + for z in range(batch_per_epoch)]).repeat() + optimizer = optimizer_fn() + batchnorm = normalization.BatchNormalization( + renorm=renorm, momentum=momentum, fused=False) + + def model_fn(x): + + def loss_fn(): + y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) + loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + return loss + + # Callable loss. + return optimizer.minimize(loss_fn) + + return model_fn, dataset, batchnorm diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..82514c64be40b421c4a9887932f2cfb8e1ac4be0 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -0,0 +1,103 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The step function abstraction represents a single training step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import backprop +from tensorflow.python.training import optimizer as optimizer_lib + + +class Step(object): + """Interface for performing each step of a training algorithm.""" + + def __init__(self, distribution): + self._distribution = distribution + + @property + def distribution(self): + return self._distribution + + def __call__(self): + """Perform one step of this training algorithm.""" + return self.step(self.inputs()) + + def inputs(self): + """For the generating the input to be passed to `step()`.""" + raise NotImplementedError("must be implemented in descendants") + + def step(self, inputs): + """Perform the main computation of this training algorithm.""" + raise NotImplementedError("must be implemented in descendants") + + +class StandardInputStep(Step): + """Step with a standard implementation of input handling. + + Args: + input_dataset: a tf.data Dataset that provides input. + """ + + def __init__(self, input_dataset, distribution): + Step.__init__(self, distribution) + self._distributed_input = distribution.distribute_dataset(input_dataset) + + def inputs(self): + return self._distributed_input.get_next() + + +class StandardSingleLossStep(StandardInputStep): + """A step function that implements a training step for a feed forward network. + + An instance of this class is intended to be used as a callable: + + ```python + ... + step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer) + step.initialize(distribution) + + # Run a single training step on a given DistributionStrategy: + step(distribution) + ... + ``` + + Args: + input_dataset: a tf.data Dataset that provides input. + loss_fn: a function that returns loss. + optimizer: an optimizer that implements an update rule. + distribution: a `DistributionStrategy` object. + """ + + def __init__(self, input_dataset, loss_fn, optimizer, distribution): + StandardInputStep.__init__(self, input_dataset, distribution) + self._loss_fn = loss_fn + self._optimizer = optimizer + self._is_run_concurrently = False + + def step(self, inputs): + with self._distribution.scope(): + gradients_fn = backprop.implicit_grad(self._loss_fn) + gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn) + + grads_and_vars = self.distribution.call_for_each_tower( + gradients_fn, inputs, run_concurrently=self._is_run_concurrently) + # If threads use layers, then we need to run the first step sequentially, + # so that layers.build() is not executed in parallel. Otherwise, multiple + # sets of mirrored variables are going to be created. + self._is_run_concurrently = True + return self._optimizer._distributed_apply( # pylint: disable=protected-access + self.distribution, grads_and_vars) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75c5ec9659d193e77d219ba79977615d58841d64 --- /dev/null +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Step.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.ops import variables + + +class SingleLossStepTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine(mode=combinations.graph_and_eager_modes))) + def testTrainNetwork(self, distribution, optimizer_fn): + with distribution.scope(): + single_loss_step, layer = single_loss_example( + optimizer_fn, distribution, use_bias=True) + + if context.executing_eagerly(): + run_step = single_loss_step + else: + with self.test_session() as sess: + run_step = sess.make_callable(single_loss_step()) + self.evaluate(variables.global_variables_initializer()) + + weights, biases = [], [] + for _ in range(10): + run_step() + + weights.append(self.evaluate(distribution.fetch(layer.kernel))) + biases.append(self.evaluate(distribution.fetch(layer.bias))) + + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) + is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) + self.assertTrue(is_not_increasing) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4ad9f146bc1d6a987fbeecbb05122946137154 --- /dev/null +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -0,0 +1,225 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for testing DistributionStrategy descendants.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer + + +class _TestException(Exception): + pass + + +# May be the argument to either distribution.call_for_each_tower() or +# get_tower_context().merge_call() +def _raise_exception_fn(_=None): + raise _TestException() + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that raises an exception. +def _merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_raise_exception_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that raises an exception. +def _call_raises_fn(dist): + dist.call_for_each_tower(_raise_exception_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, +# calls a get_tower_context().merge_call() that calls a +# call_for_each_tower() that raises an exception. +def _merge_call_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_raises_fn) + + +# Must be the argument to a get_tower_context().merge_call() call, calls +# dist.call_for_each_tower() with a function that calls a +# get_tower_context().merge_call() that raises an exception. +def _call_merge_raises_fn(dist): + dist.call_for_each_tower(_merge_raises_fn) + + +# Must be the argument to a distribution.call_for_each_tower() call, calls a +# get_tower_context().merge_call() that calls a call_for_each_tower() that +# calls a get_tower_context().merge_call() that raises an exception. +def _merge_call_merge_raises_fn(): + distribute_lib.get_tower_context().merge_call(_call_merge_raises_fn) + + +class DistributionTestBase(test.TestCase): + """Some tests that should work with any DistributionStrategy.""" + + def _test_minimize_loss_eager(self, d): + with d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a + # common `implicit_grad` function and put it in DistributionStrategy. + grad_fn = backprop.implicit_grad(loss) + grad_fn = optimizer.get_filtered_grad_fn(grad_fn) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one, run_concurrently=l.built) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + # control_dependencies irrelevant but harmless in eager execution + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + for i in range(10): + b, a = step() + if i == 0: + before, = b # pylint: disable=unbalanced-tuple-unpacking + after, = a # pylint: disable=unbalanced-tuple-unpacking + + error_before = abs(before.numpy() - 1) + error_after = abs(after.numpy() - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_minimize_loss_graph(self, d, soft_placement=False): + config = config_pb2.ConfigProto() + config.allow_soft_placement = soft_placement + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + with context.graph_mode(), \ + ops.Graph().as_default(), \ + self.test_session(config=config) as sess, \ + d.scope(): + l = core.Dense(1, use_bias=False) + + def loss(x): + # TODO(josh11b): What if this constant was instead a captured + # value? Would it need to be a value that has been passed + # through d.broadcast()? + y = array_ops.reshape(l(x), []) - constant_op.constant(1.) + return y * y + + grad_fn = backprop.implicit_grad(loss) + + def update(v, g): + return v.assign_sub(0.2 * g) + + one = d.broadcast(constant_op.constant([[1.]])) + + def step(): + """Perform one optimization step.""" + # Run forward & backward to get gradients, variables list. + g_v = d.call_for_each_tower(grad_fn, one) + + # Update the variables using the gradients and the update() function. + before_list = [] + after_list = [] + for g, v in g_v: + fetched = d.fetch(v) + before_list.append(fetched) + with ops.control_dependencies([fetched]): + g = d.reduce("sum", g, destinations=v) + with ops.control_dependencies(d.unwrap(d.update(v, update, g))): + after_list.append(d.fetch(v)) + return before_list, after_list + + before_out, after_out = step() + variables.global_variables_initializer().run() + for i in range(10): + b, a = sess.run((before_out, after_out)) + if i == 0: + before, = b + after, = a + + error_before = abs(before - 1) + error_after = abs(after - 1) + # Error should go down + self.assertLess(error_after, error_before) + + def _test_map_reduce(self, d, in_graph=None): + with d.scope(): + map_in = [constant_op.constant(i) for i in range(10)] + map_out = d.map(map_in, lambda x, y: x * y, 2) + observed = d.fetch(d.reduce("sum", map_out)) + expected = 90 # 2 * (0 + 1 + ... + 9) + self.assertEqual(expected, observed.numpy()) + + def _test_device_index(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(device_id): + self.assertLess(device_id, len(d.worker_devices)) + self.assertFalse(expected_devices[device_id]) + expected_devices[device_id] = True + + d.call_for_each_tower(mark_devices_fn, d.worker_device_index) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_tower_id(self, d): + with d.scope(): + expected_devices = [False] * len(d.worker_devices) + + def mark_devices_fn(): + tower_id = distribute_lib.get_tower_context().tower_id + self.assertLess(tower_id, len(d.worker_devices)) + self.assertFalse(expected_devices[tower_id]) + expected_devices[tower_id] = True + + d.call_for_each_tower(mark_devices_fn) + self.assertAllEqual(expected_devices, [True] * len(d.worker_devices)) + + def _test_call_and_merge_exceptions(self, dist): + with dist.scope(): + with self.assertRaises(_TestException): + dist.call_for_each_tower(_raise_exception_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_raises_fn) + with self.assertRaises(_TestException): + dist.call_for_each_tower(_merge_call_merge_raises_fn) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac307dd6a919179fda3af2448b39eaa59b811c6 --- /dev/null +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -0,0 +1,82 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU Distribution Strategy. + +This is experimental. It's not ready for general use. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops + + +# TODO(isaprykin): Consider whether inheriting is really appropriate. +class TpuStrategy(one_device_strategy.OneDeviceStrategy): + + def __init__(self, master=None, iterations=None, model_dir=None): + super(TpuStrategy, self).__init__('/cpu:0') + + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) + + # TODO(isaprykin): Give an API for many iterations per step. + iterations = 1 + + # TODO(isaprykin): Do not hard code shapes and input format :) + # TODO(isaprykin): Detect the number of TPU cores automatically. + + def dequeueing_fn(*args, **kwargs): + del args, kwargs + x, = tpu.infeed_dequeue_tuple(dtypes=[dtypes.float32], shapes=[[1, 1, 1]]) + return fn(x) + + iterator = args[0] + + def infeed_input(i): + """Get input, split it and then enqueue.""" + batches = iterator.get_next() + batches = array_ops.split(batches, 2) + + infeeds = [ + tpu_ops.infeed_enqueue_tuple( + inputs=[batches[j]], shapes=[[1, 1, 1]], device_ordinal=j) + for j in range(2) + ] + + with ops.control_dependencies(infeeds): + return i + 1 + + with ops.device('/task:0/device:CPU:0'): + enqueue_ops = control_flow_ops.while_loop( + lambda i: i < iterations, + infeed_input, [constant_op.constant(0)], + parallel_iterations=1) + + def iterate_on_tpu(): + return tpu.repeat(iterations, dequeueing_fn, []) + + with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access + tpu_result = tpu.batch_parallel(iterate_on_tpu, [], num_shards=2) + + return control_flow_ops.group(tpu_result, enqueue_ops) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py new file mode 100644 index 0000000000000000000000000000000000000000..87bf0590384cc74ca0f0575bcef4e84599a8b666 --- /dev/null +++ b/tensorflow/contrib/distribute/python/values.py @@ -0,0 +1,578 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various classes representing distributed values. + +See go/tf-distribution-strategy. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import weakref + +import six + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.contrib.eager.python import datasets +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import device_util +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import saver +from tensorflow.python.util import nest + + +# pylint: disable=line-too-long +# TODO(josh11b): Should device values be strings or DeviceSpec objects +# Not sure DeviceSpec objects are usable as a dict key. +class DistributedValues(object): + """Holds a map from device to values. Either PerDevice or Mirrored.""" + + def __init__(self, index): + self._index = {device_util.canonicalize(key): value + for key, value in six.iteritems(index)} + + def get(self, device=None): + """Returns the value for the current device or raises a ValueError.""" + if device is None: + tower_context = distribute_lib.get_tower_context() + if tower_context: + device = tower_context.device + else: + device = distribute_lib.get_update_device() + if device is None: + device = device_util.current() + device = device_util.canonicalize(device) + try: + return self._index[device] + except KeyError: + raise ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())) + + def on_device(self, device): + device = device_util.canonicalize(device) + return device in self._index + + @property + def devices(self): + return list(self._index.keys()) + + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + # TODO(josh11b): Possibly make an accessor for _index for use by + # DistributionStrategy implementations. + + +class DistributedDelegate(DistributedValues): + """A map from device to values; acts as the same type as the values.""" + + def __init__(self, index): + super(DistributedDelegate, self).__init__(index) + + def __getattr__(self, name): + return getattr(self.get(), name) + + # pylint: disable=multiple-statements + def __add__(self, o): return self.get() + o + def __radd__(self, o): return o + self.get() + def __sub__(self, o): return self.get() - o + def __rsub__(self, o): return o - self.get() + def __mul__(self, o): return self.get() * o + def __rmul__(self, o): return o * self.get() + def __truediv__(self, o): return self.get() / o + def __rtruediv__(self, o): return o / self.get() + def __floordiv__(self, o): return self.get() // o + def __rfloordiv__(self, o): return o // self.get() + def __mod__(self, o): return self.get() % o + def __rmod__(self, o): return o % self.get() + def __lt__(self, o): return self.get() < o + def __le__(self, o): return self.get() <= o + def __gt__(self, o): return self.get() > o + def __ge__(self, o): return self.get() >= o + def __and__(self, o): return self.get() & o + def __rand__(self, o): return o & self.get() + def __or__(self, o): return self.get() | o + def __ror__(self, o): return o | self.get() + def __xor__(self, o): return self.get() ^ o + def __rxor__(self, o): return o ^ self.get() + def __getitem__(self, o): return self.get()[o] + def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo) + def __rpow__(self, o): return pow(o, self.get()) + def __invert__(self): return ~self.get() + def __neg__(self): return -self.get() + def __abs__(self): return abs(self.get()) + + def __div__(self, o): + try: + return self.get().__div__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rdiv__(self, o): + try: + return self.get().__rdiv__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __matmul__(self, o): + try: + return self.get().__matmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + def __rmatmul__(self, o): + try: + return self.get().__rmatmul__(o) + except AttributeError: + # See https://docs.python.org/3/library/constants.html#NotImplemented + return NotImplemented + + # TODO(josh11b): Even more operator overloads. + + +class PerDevice(DistributedValues): + """Holds a map from device to unsynchronized values.""" + pass + + +class Mirrored(DistributedValues): + """Holds a map from device to values which are kept in sync.""" + pass + + +def _assign_on_device(device, variable, tensor): + with ops.device(device): + return variable.assign(array_ops.identity(tensor)) + + +DistributedVarOp = collections.namedtuple( + "DistributedVarOp", ["name", "graph", "type"]) + + +class DistributedVariable(DistributedDelegate): + """Holds a map from device to variables.""" + # TODO(josh11b): Support changing the set of variables if e.g. if new + # devices are joining or a device is to leave. + + def __init__(self, index): + # Child class must set self._primary_var before calling + # super(...).__init__(index). + self._common_name = self._primary_var.name.split(":")[0] + super(DistributedVariable, self).__init__(index) + + @property + def initializer(self): + return control_flow_ops.group([v.initializer for v in self._index.values()]) + + @property + def graph(self): + return self._primary_var.graph + + @property + def _shared_name(self): + return self._common_name + + @property + def _unique_id(self): + return self._primary_var._unique_id # pylint: disable=protected-access + + @property + def name(self): + return self._primary_var.name + + @property + def dtype(self): + return self._primary_var.dtype + + @property + def shape(self): + return self._primary_var.shape + + def get_shape(self): + return self._primary_var.get_shape() + + def to_proto(self, export_scope=None): + return self._primary_var.to_proto(export_scope=export_scope) + + @property + def op(self): + # We want cross-tower code that does some var.op.X calls + # to work (even if the current device isn't in self.devices), but + # other uses of var.op in a cross-tower context to fail. + if distribute_lib.get_cross_tower_context(): + return DistributedVarOp(self._primary_var.op.name, + self._primary_var.op.graph, + self._primary_var.op.type) + return self.get().op + + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + + +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) +ops.register_dense_tensor_like_type(DistributedVariable) + + +class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable): + """Class for defining how to restore a MirroredVariable.""" + + def __init__(self, mirrored_variable, primary_variable, name): + self._mirrored_variable = mirrored_variable + super(_MirroredSaveable, self).__init__(primary_variable, "", name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access + + +def _get_update_device(): + """Validate we are in update/update_non_slot() and return current device. + + This is used in MirroredVariable.assign* members, to make sure they + are only called via an update method, to make sure all components of the + variable are being updated in a consistent way. + + Returns: + A string device. + + Raises: + RuntimeError: If not in distribution.update()/.update_non_slot(). + """ + device = distribute_lib.get_update_device() + if device is None: + raise RuntimeError( + "Use DistributionStrategy.update() to modify a MirroredVariable.") + return device + + +class MirroredVariable(DistributedVariable, Mirrored, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are kept in sync.""" + + def __init__(self, index, primary_var): + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access + self._primary_var = primary_var + super(MirroredVariable, self).__init__(index) + + # We use _get_update_device() for the assign* methods to enforce + # that we are in an update() function. The arguments to update() are + # automatically unwrapped so the update() function would normally + # see regular variables, not MirroredVariables. However, the update + # function can still operate on wrapped MirroredVariables through + # object members, captured arguments, etc. This is more likely in an + # update_non_slot() function (like OptimizerV2._finish), which can + # update several non-slot variables in one call. + def assign_sub(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get(device=_get_update_device()).assign(*args, **kwargs) + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + MirroredVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _MirroredSaveable(self, self._primary_var, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): + """Class for defining how to restore a TowerLocalVariable.""" + + def __init__(self, tower_local_variable, name): + self._tower_local_variable = tower_local_variable + # We use a callable so that we don't have to evaluate this expression + # in the case where we are trying to restore instead of save. + def tensor(): + return distribute_lib.get_distribution_strategy().fetch( + tower_local_variable) + spec = saver.BaseSaverBuilder.SaveSpec( + tensor=tensor, + slice_spec="", + name=name, + dtype=tower_local_variable.dtype) + super(_TowerLocalSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into all variables.""" + tensor, = restored_tensors + # To preserve the sum across save and restore, we have to divide the + # total across all devices when restoring a variable that was summed + # when saving. + if self._tower_local_variable.reduce_method == "sum": + tensor *= 1. / len(self._tower_local_variable.devices) + return control_flow_ops.group([ + _assign_on_device(d, v, tensor) + for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access + + +class TowerLocalVariable(DistributedVariable, PerDevice, + checkpointable.CheckpointableBase): + """Holds a map from device to variables whose values are reduced on save.""" + + def __init__(self, index, primary_var, reduce_method): + self._primary_var = primary_var + self._reduce_method = reduce_method + super(TowerLocalVariable, self).__init__(index) + + def assign_sub(self, *args, **kwargs): + return self.get().assign_sub(*args, **kwargs) + + def assign_add(self, *args, **kwargs): + return self.get().assign_add(*args, **kwargs) + + def assign(self, *args, **kwargs): + return self.get().assign(*args, **kwargs) + + @property + def reduce_method(self): + return self._reduce_method + + def _gather_saveables_for_checkpoint(self): + """Overrides CheckpointableBase method. + + This allows both name-based and object-based save and restore of + TowerLocalVariables. + + Returns: + A dictionary mapping attribute names to `SaveableObject` factories. + """ + def _saveable_factory(name=self._common_name): + return _TowerLocalSaveable(self, name) + return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} + + +def _devices_match(d1, d2): + return device_util.canonicalize(d1) == device_util.canonicalize(d2) + + +def regroup(per_device, wrap_class=PerDevice): + """Makes device->nest map into a nest of PerDevice/Mirrored values.""" + items = list(per_device.items()) + assert items + v0 = items[0][1] # First value + + if isinstance(v0, list): + for _, v in items[1:]: + assert isinstance(v, list) + assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" % + (len(v), len(v0), v, v0)) + return [regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))] + + if isinstance(v0, tuple): + for _, v in items[1:]: + assert isinstance(v, tuple) + assert len(v) == len(v0) + regrouped_tuple = tuple(regroup({k: v[i] for k, v in items}, wrap_class) + for i in range(len(v0))) + if hasattr(v0, "_fields"): + # This tuple is in fact a namedtuple! Create a new namedtuple instance + # and initialize it with the regrouped values: + assert hasattr(type(v0), "_make") + return type(v0)._make(regrouped_tuple) + else: + return regrouped_tuple + + if isinstance(v0, dict): + v0keys = set(v0.keys()) + for _, v in items[1:]: + assert isinstance(v, dict) + assert set(v.keys()) == v0keys + return {key: regroup({k: v[key] for k, v in items}, wrap_class) + for key in v0keys} + + # If exactly the same object across all devices, return it unwrapped. + same_id = True + for _, v in items[1:]: + if v is not v0: + same_id = False + break + # Consider three cases where same_id is true: + # * If v0 is a MirroredVariable (and same_id means it is the same + # across all devices), we want to return it. We check + # MirroredVariable specifically since it can look like it + # has a _mirrored_container member since its members do. + # * If v0 is a member of a mirrored variable, in which case + # hasattr(v0, "_mirrored_container") is true, we want to + # return the MirroredVariable that contains it using the + # _mirrored_container logic below. This case can trigger + # same_id when there is only one device. + # * In any other situation, same_id means we return v0. + if same_id and (isinstance(v0, MirroredVariable) or + not hasattr(v0, "_mirrored_container")): + return v0 + + # Detect the case where each device has a parallel component of the + # same MirroredVariable. In this case we want to return the + # containing MirroredVariable, after a bunch of sanity checking. + # In particular, each component should have the same container, + # and the devices of the variables should match the keys of the + # per-device dictionary. + # TODO(josh11b): Do we need similar logic for TowerLocalVariables? + if hasattr(v0, "_mirrored_container"): + # pylint: disable=protected-access + assert not isinstance(v0, MirroredVariable), ( + "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) + assert _devices_match(v0.device, items[0][0]), ( + "v0.device = %s, items = %s" % (v0.device, items)) + mirrored_container = v0._mirrored_container() + assert mirrored_container is not None + for d, v in items[1:]: + assert _devices_match(v.device, d), ( + "v.device = %s, d = %s, items = %s" % (v.device, d, items)) + assert mirrored_container is v._mirrored_container() + return mirrored_container + # pylint: enable=protected-access + + return wrap_class(per_device) + + +def select_device(device, structured): + """Specialize a nest of regular & per-device values for one device.""" + def _get(x): + return x.get(device) if isinstance(x, DistributedValues) else x + + return nest.map_structure(_get, structured) + + +def select_device_mirrored(device, structured): + """Specialize a nest of regular & mirrored values for one device.""" + def _get_mirrored(x): + if isinstance(x, DistributedValues): + if not isinstance(x, Mirrored): + raise TypeError( + "Expected value to be mirrored across towers: %s in %s." % + (x, structured)) + return x.get(device) + else: + return x + + return nest.map_structure(_get_mirrored, structured) + + +class PerDeviceDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`.""" + + def __init__(self, iterator, devices, prefetch_on_device=None): + self._iterator = iterator + self._devices = devices + self._prefetch_on_device = prefetch_on_device + + def get_next(self, name=None): + """Scatter the input across devices.""" + if self._prefetch_on_device: + data_list = self._iterator.get_next(name=name) + index = dict(zip(self._devices, data_list)) + else: + batch = self._iterator.get_next(name=name) + index = {} + def get_ith(i): + return lambda x: x[i] + + for i, d in enumerate(self._devices): + index[d] = nest.map_structure(get_ith(i), batch) + if context.executing_eagerly(): + with ops.device(d): + index[d] = nest.map_structure(array_ops.identity, index[d]) + + return regroup(index) + + +class PerDeviceDataset(object): + """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" + + def __init__(self, dataset, devices, prefetch_on_device=None): + self._devices = devices + + # Default to using prefetching in graph mode, unless specified. + # TODO(priyag): Enable prefetching in eager mode. + self._prefetch_on_device = prefetch_on_device + if self._prefetch_on_device is None: + self._prefetch_on_device = not context.executing_eagerly() + assert not (self._prefetch_on_device and context.executing_eagerly()), ( + "Prefetching is only supported in graph mode currently") + + if self._prefetch_on_device: + self._dataset = dataset + else: + # TODO(priyag): If dropping remainder is not appropriate, find another + # approach to distributing the dataset when not possible to divide evenly. + # Possibly not an issue when we start using PartitionedDataset. + self._dataset = dataset.apply( + batching.batch_and_drop_remainder(len(devices))) + + def make_one_shot_iterator(self): + """Get a one time use iterator for the distributed PerDeviceDataset.""" + if self._prefetch_on_device: + on_device_dataset = self._dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + dataset_iterator = on_device_dataset.make_one_shot_iterator() + elif context.executing_eagerly(): + dataset_iterator = datasets.Iterator(self._dataset) + else: + dataset_iterator = self._dataset.make_one_shot_iterator() + + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + + +class MapOutput(object): + """Map can result in multiple outputs per device.""" + + def __init__(self, l): + self._l = l + + def get(self): + return self._l diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0d4b7d6c78b7cf63c613201d83d4793ecfe76b --- /dev/null +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -0,0 +1,807 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the distributed values library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import device_util +from tensorflow.python.training import saver as saver_lib + + +@test_util.with_c_api +class DistributedValuesTest(test.TestCase): + + def testGetEager(self): + with ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testGetGraph(self): + with context.graph_mode(), \ + ops.Graph().as_default(), \ + ops.device("/device:CPU:0"): + one = constant_op.constant(1) + two = constant_op.constant(2) + v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two}) + self.assertEqual(two, v.get("/device:GPU:0")) + self.assertEqual(one, v.get()) + with self.assertRaises(ValueError): + self.assertIsNone(v.get("/device:GPU:2")) + + def testCanonicalization(self): + canonical_cpu = ["/job:localhost/replica:0/task:0/device:CPU:0"] + v = values.DistributedValues({"": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/device:CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/cpu:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + v = values.DistributedValues({"/CPU:0": 42}) + self.assertEqual(canonical_cpu, list(v._index.keys())) + with self.assertRaises(AssertionError): + v = values.DistributedValues({"/device:cpu:0": 42}) + + +@test_util.with_c_api +class DistributedDelegateTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testGetAttr(self): + with ops.device("/device:CPU:0"): + + class Foo(object): + + def __init__(self, x): + self.x = x + + v = values.DistributedDelegate( + {"/device:CPU:0": Foo(7), "/device:GPU:0": Foo(8)}) + self.assertEqual(7, v.x) + with self.assertRaises(AttributeError): + _ = v.y + + @test_util.run_in_graph_and_eager_modes() + def testOperatorOverride(self): + with ops.device("/device:CPU:0"): + v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8}) + # v should act like int(7). + self.assertEqual(8, v + 1) + self.assertEqual(10, 3 + v) + self.assertEqual(14, v + v) + self.assertEqual(5, v - 2) + self.assertEqual(6, 13 - v) + self.assertEqual(0, v - v) + self.assertEqual(14, v * 2) + self.assertEqual(21, 3 * v) + self.assertEqual(49, v * v) + self.assertEqual(3.5, v / 2) + self.assertEqual(1.5, 10.5 / v) + self.assertEqual(3, v // 2) + self.assertEqual(2, 15 // v) + self.assertEqual(1, v % 2) + self.assertEqual(2, 16 % v) + self.assertTrue(v < 12) + self.assertTrue(v <= 12) + self.assertFalse(v > 12) + self.assertFalse(v >= 12) + self.assertFalse(12 < v) + self.assertFalse(12 <= v) + self.assertTrue(12 > v) + self.assertTrue(12 >= v) + self.assertEqual(3, v & 3) + self.assertEqual(3, 11 & v) + self.assertEqual(15, v | 8) + self.assertEqual(23, 16 | v) + self.assertEqual(4, v ^ 3) + self.assertEqual(12, 11 ^ v) + self.assertEqual(343, pow(v, 3)) + self.assertEqual(3, pow(v, 3, 10)) + self.assertEqual(128, pow(2, v)) + self.assertEqual(-7, -v) + self.assertEqual(~7, ~v) + self.assertEqual(7, abs(v)) + with self.assertRaises(TypeError): + _ = v[2] + + +def _device_str(d): + return "/device:GPU:" + str(d) + + +def _nested_value(d): + return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d) + + +def _make_mirrored(): + v = [] + index = {} + devices = ["/device:GPU:0", "/device:CPU:0"] + for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + mirrored = values.MirroredVariable(index, v[0]) + return v, devices, mirrored + + +@test_util.with_c_api +class RegroupAndSelectDeviceTest(test.TestCase): + + def _is_per_device(self, result, expected, klass=values.PerDevice): + self.assertIsInstance(result, klass) + # We canonicalize the devices to match the device strings returned + # by PerDevice, which also does device string canonicalization. + devices = [device_util.canonicalize(_device_str(i)) + for i in range(len(expected))] + self.assertEqual(set(devices), set(result.devices)) + for i, d in enumerate(devices): + self.assertEqual(expected[i], result.get(d)) + self.assertEqual(expected[i], result.get(_device_str(i))) + + def testNested(self): + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"]) + self._is_per_device(result[2], ["h1", "h2"]) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"]) + self._is_per_device(result[1][2], ["g1", "g2"]) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"]) + self._is_per_device(result[1][1]["e"], ["f1", "f2"]) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # select_device_mirrored() should fail due to non-mirrored values + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(0), result) + with self.assertRaises(TypeError): + values.select_device_mirrored(_device_str(1), result) + + def testWrapClass(self): + # Normally a mirrored value would be the same across devices, but + # for a test it is convenient to be able to tell the values apart. + result = values.regroup({_device_str(0): _nested_value("1"), + _device_str(1): _nested_value("2")}, + values.Mirrored) + self.assertIsInstance(result, tuple) + self.assertEqual(3, len(result)) + self._is_per_device(result[0], ["a1", "a2"], values.Mirrored) + self._is_per_device(result[2], ["h1", "h2"], values.Mirrored) + + self.assertIsInstance(result[1], list) + self.assertEqual(3, len(result[1])) + self._is_per_device(result[1][0], ["b1", "b2"], values.Mirrored) + self._is_per_device(result[1][2], ["g1", "g2"], values.Mirrored) + + self.assertIsInstance(result[1][1], dict) + self.assertEqual(set(["c", "e"]), set(result[1][1].keys())) + self._is_per_device(result[1][1]["c"], ["d1", "d2"], values.Mirrored) + self._is_per_device(result[1][1]["e"], ["f1", "f2"], values.Mirrored) + + # Also test that we can undo the merge using select_device() + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device(_device_str(1), result)) + # Values are marked as mirrored, so select_device_mirrored() is allowed. + self.assertEqual(_nested_value("1"), + values.select_device_mirrored(_device_str(0), result)) + self.assertEqual(_nested_value("2"), + values.select_device_mirrored(_device_str(1), result)) + + def testMirroredContainer(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + v, devices, mirrored = _make_mirrored() + result = values.regroup(dict(zip(devices, v))) + self.assertIs(mirrored, result) + + def testSameId(self): + foo = object() + result = values.regroup({_device_str(0): ("a", foo), + _device_str(1): ("b", foo)}) + self.assertIsInstance(result, tuple) + self.assertEqual(2, len(result)) + self._is_per_device(result[0], ["a", "b"]) + self.assertIs(foo, result[1]) + + # Test select_device(), should undo the merge done by regroup(). + result_0 = values.select_device(_device_str(0), result) + self.assertIsInstance(result_0, tuple) + self.assertEqual(2, len(result_0)) + self.assertEqual("a", result_0[0]) + self.assertIs(foo, result_0[1]) + result_1 = values.select_device(_device_str(1), result) + self.assertIsInstance(result_1, tuple) + self.assertEqual(2, len(result_1)) + self.assertEqual("b", result_1[0]) + self.assertIs(foo, result_1[1]) + + def testOneDevice(self): + result = values.regroup({_device_str(0): _nested_value("1")}) + # On one device regroup() and select_device() are basically identity. + self.assertEqual(_nested_value("1"), result) + self.assertEqual(_nested_value("1"), + values.select_device(_device_str(0), result)) + + # The one exception has to do with MirroredVariables. + d = "/device:CPU:0" + with ops.device(d): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + index = {d: v} + mirrored = values.MirroredVariable(index, v) + result = values.regroup(index) + self.assertIs(mirrored, result) + + def testNamedTupleEstimatorSpec(self): + with context.graph_mode(), ops.Graph().as_default(): + created_estimator_specs = {} + to_regroup = {} + + for device_id in range(3): + spec = model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=constant_op.constant(device_id / 2), + train_op=array_ops.identity(constant_op.constant(device_id))) + created_estimator_specs[device_id] = spec + to_regroup[_device_str(device_id)] = spec + + merged_estimator_spec = values.regroup(to_regroup) + + self.assertTrue( + isinstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)) + self.assertEquals(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode) + for device_id in range(3): + d = _device_str(device_id) + self.assertEquals(created_estimator_specs[device_id].loss, + merged_estimator_spec.loss.get(d)) + self.assertEquals(created_estimator_specs[device_id].train_op, + merged_estimator_spec.train_op.get(d)) + # Scaffold is populated by `EstimatorSpec.__new__`. + self.assertEquals(created_estimator_specs[device_id].scaffold, + merged_estimator_spec.scaffold.get(d)) + # Also test that we can undo the merge using select_device() + self.assertEquals(created_estimator_specs[device_id], + values.select_device(_device_str(device_id), + merged_estimator_spec)) + + +@test_util.with_c_api +class PerDeviceDatasetTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + def _test_iterator_no_prefetch(self, devices, dataset, expected_values): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=False) + iterator = per_device_dataset.make_one_shot_iterator() + + for expected_value in expected_values: + next_element = iterator.get_next() + actual = self.evaluate([ + values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator_with_prefetch(self, devices, dataset, expected_values): + if not context.executing_eagerly(): + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=True) + iterator = per_device_dataset.make_one_shot_iterator() + + # With prefetching, we cannot guarantee which input ends up on which + # device, so we verify that the complete set seen on all devices is + # correct, and equal numbers are distributed to each device. + combined_actual = [] + combined_expected = [] + for expected_value in expected_values: + next_element = iterator.get_next() + combined_actual.extend(self.evaluate([ + values.select_device(d, next_element) for d in devices])) + combined_expected.extend(expected_value) + + self.assertEqual(set(combined_expected), set(combined_actual)) + + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + self.evaluate([ + values.select_device(d, next_element) for d in devices]) + + def _test_iterator(self, devices, dataset, expected_values): + self._test_iterator_no_prefetch(devices, dataset, expected_values) + self._test_iterator_with_prefetch(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes() + def testOneDevice(self): + devices = ["/device:CPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i] for i in range(10)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testMultipleDevices(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(10) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testTupleDataset(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] + + self._test_iterator(devices, dataset, expected_values) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testUnevenDatasetBatches(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + devices = ["/device:CPU:0", "/device:GPU:0"] + dataset = dataset_ops.Dataset.range(11) + + expected_values = [[i, i+1] for i in range(0, 10, 2)] + self._test_iterator(devices, dataset, expected_values) + + +@test_util.with_c_api +class MirroredVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, _, mirrored = _make_mirrored() + + self.assertEquals(v[0].name, mirrored.name) + self.assertEquals(v[0].dtype, mirrored.dtype) + self.assertEquals(v[0].shape, mirrored.shape) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + mirrored = values.MirroredVariable(index, v) + + self.assertEquals(v.name, mirrored.name) + self.assertEquals(v.dtype, mirrored.dtype) + self.assertEquals(v.shape, mirrored.shape) + + def _assign_mirrored(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreMirroredOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path, saver = self._save_return_saver(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + + # Restores the saved value of 3. to both variables. + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + def _save_mirrored(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [3., 4.]) + + # Saves the current value of v[0], 3. + save_path = self._save(sess, mirrored) + + # Change the values between save and restore. + self._assign_mirrored(devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.)) + + # Saves the current value of var, 3. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3. to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3., self.evaluate(var)) + + def _restore_mirrored(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, devices, mirrored = _make_mirrored() + + # Overwrite the initial values. + self._assign_mirrored(devices, v, [7., 8.]) + + # Restores the saved value of 3. to both variables. + saver = saver_lib.Saver(var_list=[mirrored]) + saver.restore(sess, save_path) + self.assertEqual([3., 3.], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_mirrored(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveMirroredRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_mirrored() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreMirrored(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_mirrored(save_path) + + +_devices = ["/device:GPU:0", "/device:CPU:0"] + + +def _make_tower_local(method): + v = [] + index = {} + for d, n, init in zip(_devices, ["v", "v/replica"], [1., 2.]): + with ops.device(d): + v.append(variable_scope.get_variable( + name=n, initializer=init, use_resource=True)) + index[d] = v[-1] + tower_local = values.TowerLocalVariable(index, v[0], method) + return v, tower_local + + +@test_util.with_c_api +class TowerLocalVariableTest(test.TestCase): + + config = config_pb2.ConfigProto() + config.allow_soft_placement = True + + @test_util.run_in_graph_and_eager_modes(config=config) + def testProperties(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + v, tower_local = _make_tower_local("sum") + + self.assertEquals(v[0].name, tower_local.name) + self.assertEquals(v[0].dtype, tower_local.dtype) + self.assertEquals(v[0].shape, tower_local.shape) + self.assertEquals("sum", tower_local.reduce_method) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testVariableOnAnotherDevice(self): + v = variable_scope.get_variable( + name="v", initializer=[1.], use_resource=True) + index = {"/job:foo/device:CPU:0": v} + tower_local = values.TowerLocalVariable(index, v, "mean") + + self.assertEquals(v.name, tower_local.name) + self.assertEquals(v.dtype, tower_local.dtype) + self.assertEquals(v.shape, tower_local.shape) + self.assertEquals("mean", tower_local.reduce_method) + + def _assign_tower_local(self, devices, v, new): + for d, var, n in zip(devices, v, new): + with ops.device(d): + self.evaluate(var.assign(n)) + + def _save_return_saver(self, sess, var): + saver = saver_lib.Saver(var_list=[var]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + return saver.save(sess, prefix), saver + + def _save(self, sess, var): + save_path, _ = self._save_return_saver(sess, var) + return save_path + + def _dist_scope(self): + return mirrored_strategy.MirroredStrategy(_devices).scope() + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalSumOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 7. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 7. which gets divided equally + # between the variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveAndRestoreTowerLocalMeanOneGraph(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + with self.test_session() as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5. + save_path, saver = self._save_return_saver(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + + # Restores the saved value of 3.5 to both variables. + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _save_tower_local_mean(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [3., 4.]) + + with self._dist_scope(): + # Saves the current value of (v[0] + v[1])/2, 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_tower_local_sum(self): + """Save variables with mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [1.5, 2.]) + + with self._dist_scope(): + # Saves the current value of v[0] + v[1], 3.5 + save_path = self._save(sess, tower_local) + + # Change the values between save and restore. + self._assign_tower_local(_devices, v, [5., 6.]) + return save_path + + def _save_normal(self): + """Save variables without mirroring, returns save_path.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(3.5)) + + # Saves the current value of var, 3.5. + save_path = self._save(sess, var) + + # Change the values between save and restore. + self.evaluate(var.assign(5.)) + return save_path + + def _restore_normal(self, save_path): + """Restore to variables without mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + var = variable_scope.get_variable( + name="v", initializer=7., use_resource=True) + + # Overwrite the initial value. + self.evaluate(var.assign(8.)) + + # Restores the saved value of 3.5 to `var`. + saver = saver_lib.Saver(var_list=[var]) + saver.restore(sess, save_path) + self.assertEqual(3.5, self.evaluate(var)) + + def _restore_tower_local_mean(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("mean") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]])) + + def _restore_tower_local_sum(self, save_path): + """Restore to variables with mirroring in a fresh graph.""" + with self.test_session(graph=ops.Graph()) as sess: + v, tower_local = _make_tower_local("sum") + + # Overwrite the initial values. + self._assign_tower_local(_devices, v, [7., 8.]) + + with self._dist_scope(): + # Restores the saved value of 3.5 to both variables. + saver = saver_lib.Saver(var_list=[tower_local]) + saver.restore(sess, save_path) + self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]])) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_tower_local_sum(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalMeanRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_mean() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveTowerLocalSumRestoreNormal(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_tower_local_sum() + self._restore_normal(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalMean(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_mean(save_path) + + @test_util.run_in_graph_and_eager_modes(config=config) + def testSaveNormalRestoreTowerLocalSum(self): + if context.num_gpus() < 1 and context.executing_eagerly(): + self.skipTest("A GPU is not available for this test in eager mode.") + + save_path = self._save_normal() + self._restore_tower_local_sum(save_path) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 7f510c42215f48a9e795eb81bd9f66b0a2108335..20e432b88dc60d45fd32710574ed6e57d0f8a792 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -251,6 +251,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_test", + srcs = ["python/kernel_tests/kumaraswamy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "moving_stats_test", size = "small", @@ -335,6 +350,7 @@ cuda_py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", ], + tags = ["nomsan"], ) cuda_py_test( @@ -403,7 +419,7 @@ cuda_py_test( cuda_py_test( name = "poisson_lognormal_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/poisson_lognormal_test.py"], additional_deps = [ ":distributions_py", @@ -438,6 +454,21 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = ["no_windows"], # TODO: needs investigation on Windows +) + +cuda_py_test( + name = "batch_reshape_test", + size = "small", + srcs = ["python/kernel_tests/batch_reshape_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], ) cuda_py_test( @@ -459,6 +490,30 @@ cuda_py_test( tags = ["nomsan"], # disable to avoid false positives from scipy. ) +cuda_py_test( + name = "seed_stream_test", + size = "small", + srcs = ["python/kernel_tests/seed_stream_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:client_testlib", + ], +) + +cuda_py_test( + name = "statistical_testing_test", + size = "medium", + srcs = [ + "python/kernel_tests/statistical_testing_test.py", + ], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + ], + shard_count = 4, +) + cuda_py_test( name = "vector_sinh_arcsinh_diag_test", size = "medium", @@ -710,18 +765,6 @@ cuda_py_test( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - # === Bijector Tests ========================================================== cuda_py_test( @@ -782,6 +825,25 @@ cuda_py_test( tags = ["noasan"], # times out b/63678675 ) +cuda_py_test( + name = "affine_scalar_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/affine_scalar_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "affine_linear_operator_test", size = "small", @@ -801,6 +863,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "batch_normalization_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "chain_test", size = "small", @@ -915,6 +993,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_bijector_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "masked_autoregressive_test", size = "small", @@ -984,7 +1081,7 @@ cuda_py_test( cuda_py_test( name = "reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/reshape_test.py"], additional_deps = [ ":bijectors_py", @@ -1017,10 +1114,12 @@ cuda_py_test( ], ) +# Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to +# avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( - name = "sigmoid_centered_test", + name = "sinh_arcsinh_bijector_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sigmoid_centered_test.py"], + srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1034,14 +1133,13 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) -# Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to -# avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( - name = "sinh_arcsinh_bijector_test", + name = "softmax_centered_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py"], + srcs = ["python/kernel_tests/bijectors/softmax_centered_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1058,9 +1156,9 @@ cuda_py_test( ) cuda_py_test( - name = "softmax_centered_test", + name = "softplus_test", size = "small", - srcs = ["python/kernel_tests/bijectors/softmax_centered_test.py"], + srcs = ["python/kernel_tests/bijectors/softplus_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", @@ -1077,9 +1175,28 @@ cuda_py_test( ) cuda_py_test( - name = "softplus_test", + name = "softsign_test", size = "small", - srcs = ["python/kernel_tests/bijectors/softplus_test.py"], + srcs = ["python/kernel_tests/bijectors/softsign_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "square_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/square_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 61c411271d0bb8d7b4cc3b14992b82ec1e5674ed..ddf59891e626a85e6c917ac74b3cfaabf16eb15d 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -24,6 +24,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.autoregressive import * +from tensorflow.contrib.distributions.python.ops.batch_reshape import * from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * @@ -58,6 +59,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.sample_stats import * +from tensorflow.contrib.distributions.python.ops.seed_stream import * from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.test_util import * from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * @@ -96,9 +98,10 @@ _allowed_symbols = [ 'ReparameterizationType', 'Distribution', 'Autoregressive', - 'Binomial', + 'BatchReshape', 'Bernoulli', 'Beta', + 'Binomial', 'BetaWithSoftplusConcentration', 'Categorical', 'Chi2', @@ -124,6 +127,7 @@ _allowed_symbols = [ 'NormalWithSoftplusScale', 'Poisson', 'PoissonLogNormalQuadratureCompound', + 'SeedStream', 'SinhArcsinh', 'StudentT', 'StudentTWithAbsDfSoftplusScale', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..59d549b7b80a3d80d0b8409542eb6583f645bdaa --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -0,0 +1,568 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BatchReshape.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import batch_reshape as batch_reshape_lib +from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_lib +from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib +from tensorflow.contrib.distributions.python.ops import wishart as wishart_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class _BatchReshapeTest(object): + + def make_wishart(self, dims, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = self.dtype([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale], axis=0), + old_batch_shape + [dims, dims]) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + wishart = wishart_lib.WishartFull(df=5, scale=scale_ph) + reshape_wishart = batch_reshape_lib.BatchReshape( + distribution=wishart, + batch_shape=new_batch_shape_ph, + validate_args=True) + + return wishart, reshape_wishart + + def test_matrix_variate_sample_and_log_prob(self): + dims = 2 + new_batch_shape = [4] + old_batch_shape = [2, 2] + wishart, reshape_wishart = self.make_wishart( + dims, new_batch_shape, old_batch_shape) + + batch_shape = reshape_wishart.batch_shape_tensor() + event_shape = reshape_wishart.event_shape_tensor() + + expected_sample_shape = [3, 1] + new_batch_shape + [dims, dims] + x = wishart.sample([3, 1], seed=42) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_wishart.sample([3, 1], seed=42) + + expected_log_prob_shape = [3, 1] + new_batch_shape + expected_log_prob = array_ops.reshape( + wishart.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_wishart.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([dims, dims], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_wishart.batch_shape) + self.assertAllEqual([dims, dims], reshape_wishart.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_matrix_variate_stats(self): + dims = 2 + new_batch_shape = [4] + old_batch_shape = [2, 2] + wishart, reshape_wishart = self.make_wishart( + dims, new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + expected_matrix_stat_shape = new_batch_shape + [dims, dims] + + expected_entropy = array_ops.reshape( + wishart.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_wishart.entropy() + + expected_mean = array_ops.reshape( + wishart.mean(), expected_matrix_stat_shape) + actual_mean = reshape_wishart.mean() + + expected_mode = array_ops.reshape( + wishart.mode(), expected_matrix_stat_shape) + actual_mode = reshape_wishart.mode() + + expected_stddev = array_ops.reshape( + wishart.stddev(), expected_matrix_stat_shape) + actual_stddev = reshape_wishart.stddev() + + expected_variance = array_ops.reshape( + wishart.variance(), expected_matrix_stat_shape) + actual_variance = reshape_wishart.variance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + ]) + + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_variance.shape) + + def make_normal(self, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = self.dtype(0.5 + np.arange( + np.prod(old_batch_shape)).reshape(old_batch_shape)) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + normal = normal_lib.Normal(loc=self.dtype(0), scale=scale_ph) + reshape_normal = batch_reshape_lib.BatchReshape( + distribution=normal, + batch_shape=new_batch_shape_ph, + validate_args=True) + return normal, reshape_normal + + def test_scalar_variate_sample_and_log_prob(self): + new_batch_shape = [2, 2] + old_batch_shape = [4] + + normal, reshape_normal = self.make_normal( + new_batch_shape, old_batch_shape) + + batch_shape = reshape_normal.batch_shape_tensor() + event_shape = reshape_normal.event_shape_tensor() + + expected_sample_shape = new_batch_shape + x = normal.sample(seed=52) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_normal.sample(seed=52) + + expected_log_prob_shape = new_batch_shape + expected_log_prob = array_ops.reshape( + normal.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_normal.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_normal.batch_shape) + self.assertAllEqual([], reshape_normal.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_scalar_variate_stats(self): + new_batch_shape = [2, 2] + old_batch_shape = [4] + + normal, reshape_normal = self.make_normal(new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + + expected_entropy = array_ops.reshape( + normal.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_normal.entropy() + + expected_mean = array_ops.reshape( + normal.mean(), expected_scalar_stat_shape) + actual_mean = reshape_normal.mean() + + expected_mode = array_ops.reshape( + normal.mode(), expected_scalar_stat_shape) + actual_mode = reshape_normal.mode() + + expected_stddev = array_ops.reshape( + normal.stddev(), expected_scalar_stat_shape) + actual_stddev = reshape_normal.stddev() + + expected_variance = array_ops.reshape( + normal.variance(), expected_scalar_stat_shape) + actual_variance = reshape_normal.variance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + ]) + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_scalar_stat_shape, actual_variance.shape) + + def make_mvn(self, dims, new_batch_shape, old_batch_shape): + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + reshape_mvn = batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + return mvn, reshape_mvn + + def test_vector_variate_sample_and_log_prob(self): + dims = 3 + new_batch_shape = [2, 1] + old_batch_shape = [2] + mvn, reshape_mvn = self.make_mvn( + dims, new_batch_shape, old_batch_shape) + + batch_shape = reshape_mvn.batch_shape_tensor() + event_shape = reshape_mvn.event_shape_tensor() + + expected_sample_shape = [3] + new_batch_shape + [dims] + x = mvn.sample(3, seed=62) + expected_sample = array_ops.reshape(x, expected_sample_shape) + actual_sample = reshape_mvn.sample(3, seed=62) + + expected_log_prob_shape = [3] + new_batch_shape + expected_log_prob = array_ops.reshape( + mvn.log_prob(x), expected_log_prob_shape) + actual_log_prob = reshape_mvn.log_prob(expected_sample) + + with self.test_session() as sess: + [ + batch_shape_, + event_shape_, + expected_sample_, actual_sample_, + expected_log_prob_, actual_log_prob_, + ] = sess.run([ + batch_shape, + event_shape, + expected_sample, actual_sample, + expected_log_prob, actual_log_prob, + ]) + self.assertAllEqual(new_batch_shape, batch_shape_) + self.assertAllEqual([dims], event_shape_) + self.assertAllClose(expected_sample_, actual_sample_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_log_prob_, actual_log_prob_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(new_batch_shape, reshape_mvn.batch_shape) + self.assertAllEqual([dims], reshape_mvn.event_shape) + self.assertAllEqual(expected_sample_shape, actual_sample.shape) + self.assertAllEqual(expected_log_prob_shape, actual_log_prob.shape) + + def test_vector_variate_stats(self): + dims = 3 + new_batch_shape = [2, 1] + old_batch_shape = [2] + mvn, reshape_mvn = self.make_mvn( + dims, new_batch_shape, old_batch_shape) + + expected_scalar_stat_shape = new_batch_shape + + expected_entropy = array_ops.reshape( + mvn.entropy(), expected_scalar_stat_shape) + actual_entropy = reshape_mvn.entropy() + + expected_vector_stat_shape = new_batch_shape + [dims] + + expected_mean = array_ops.reshape( + mvn.mean(), expected_vector_stat_shape) + actual_mean = reshape_mvn.mean() + + expected_mode = array_ops.reshape( + mvn.mode(), expected_vector_stat_shape) + actual_mode = reshape_mvn.mode() + + expected_stddev = array_ops.reshape( + mvn.stddev(), expected_vector_stat_shape) + actual_stddev = reshape_mvn.stddev() + + expected_variance = array_ops.reshape( + mvn.variance(), expected_vector_stat_shape) + actual_variance = reshape_mvn.variance() + + expected_matrix_stat_shape = new_batch_shape + [dims, dims] + + expected_covariance = array_ops.reshape( + mvn.covariance(), expected_matrix_stat_shape) + actual_covariance = reshape_mvn.covariance() + + with self.test_session() as sess: + [ + expected_entropy_, actual_entropy_, + expected_mean_, actual_mean_, + expected_mode_, actual_mode_, + expected_stddev_, actual_stddev_, + expected_variance_, actual_variance_, + expected_covariance_, actual_covariance_, + ] = sess.run([ + expected_entropy, actual_entropy, + expected_mean, actual_mean, + expected_mode, actual_mode, + expected_stddev, actual_stddev, + expected_variance, actual_variance, + expected_covariance, actual_covariance, + ]) + self.assertAllClose(expected_entropy_, actual_entropy_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mean_, actual_mean_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_mode_, actual_mode_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_stddev_, actual_stddev_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_variance_, actual_variance_, + atol=0., rtol=1e-6) + self.assertAllClose(expected_covariance_, actual_covariance_, + atol=0., rtol=1e-6) + if not self.is_static_shape: + return + self.assertAllEqual(expected_scalar_stat_shape, actual_entropy.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_mean.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_mode.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_stddev.shape) + self.assertAllEqual(expected_vector_stat_shape, actual_variance.shape) + self.assertAllEqual(expected_matrix_stat_shape, actual_covariance.shape) + + def test_bad_reshape_size(self): + dims = 2 + new_batch_shape = [2, 3] + old_batch_shape = [2] # 2 != 2*3 + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp( + ValueError, (r"`batch_shape` size \(6\) must match " + r"`distribution\.batch_shape` size \(2\)")): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r"`batch_shape` size must match " + r"`distributions.batch_shape` size"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_non_positive_shape(self): + dims = 2 + new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. + old_batch_shape = [2] + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r".*must be positive.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_non_vector_shape(self): + dims = 2 + new_batch_shape = 2 + old_batch_shape = [2] + + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + + scale = np.ones(old_batch_shape + [dims], self.dtype) + scale_ph = array_ops.placeholder_with_default( + scale, shape=scale.shape if self.is_static_shape else None) + mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) + + if self.is_static_shape: + with self.assertRaisesRegexp(ValueError, r".*must be a vector.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True) + + else: + with self.test_session(): + with self.assertRaisesOpError(r".*must be a vector.*"): + batch_reshape_lib.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape_ph, + validate_args=True).sample().eval() + + def test_broadcasting_explicitly_unsupported(self): + old_batch_shape = [4] + new_batch_shape = [1, 4, 1] + rate_ = self.dtype([1, 10, 2, 20]) + + rate = array_ops.placeholder_with_default( + rate_, + shape=old_batch_shape if self.is_static_shape else None) + poisson_4 = poisson_lib.Poisson(rate) + new_batch_shape_ph = ( + constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape + else array_ops.placeholder_with_default( + np.int32(new_batch_shape), shape=None)) + poisson_141_reshaped = batch_reshape_lib.BatchReshape( + poisson_4, new_batch_shape_ph, validate_args=True) + + x_4 = self.dtype([2, 12, 3, 23]) + x_114 = self.dtype([2, 12, 3, 23]).reshape(1, 1, 4) + + if self.is_static_shape: + with self.assertRaisesRegexp(NotImplementedError, + "too few batch and event dims"): + poisson_141_reshaped.log_prob(x_4) + with self.assertRaisesRegexp(NotImplementedError, + "unexpected batch and event shape"): + poisson_141_reshaped.log_prob(x_114) + return + + with self.assertRaisesOpError("too few batch and event dims"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_4).eval() + + with self.assertRaisesOpError("unexpected batch and event shape"): + with self.test_session(): + poisson_141_reshaped.log_prob(x_114).eval() + + +class BatchReshapeStaticTest(_BatchReshapeTest, test.TestCase): + + dtype = np.float32 + is_static_shape = True + + +class BatchReshapeDynamicTest(_BatchReshapeTest, test.TestCase): + + dtype = np.float64 + is_static_shape = False + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py index e0d65c79b2654c2949de161d6317f218d11cab43..042c8ebd51c47facfc5c942cae56bd56be9df7c5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -18,11 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - # pylint: disable=g-importing-member from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import AbsoluteValue -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,50 +32,38 @@ class AbsoluteValueTest(test.TestCase): def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) self.assertEqual("absolute_value", bijector.name) x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3] y = math_ops.abs(x) y_ = y.eval() - zeros = np.zeros((2, 3)) self.assertAllClose(y_, bijector.forward(x).eval()) self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) - self.assertAllClose((zeros, zeros), - sess.run(bijector.inverse_log_det_jacobian(y))) + self.assertAllClose((0., 0.), + sess.run(bijector.inverse_log_det_jacobian( + y, event_ndims=0))) # Run things twice to make sure there are no issues in caching the tuples # returned by .inverse* self.assertAllClose(y_, bijector.forward(x).eval()) self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) - self.assertAllClose((zeros, zeros), - sess.run(bijector.inverse_log_det_jacobian(y))) - - def testEventNdimsMustBeZeroOrRaiseStatic(self): - with self.test_session(): - with self.assertRaisesRegexp(ValueError, "event_ndims.*was not 0"): - AbsoluteValue(event_ndims=1) - - def testEventNdimsMustBeZeroOrRaiseDynamic(self): - with self.test_session() as sess: - event_ndims = array_ops.placeholder(dtypes.int32) - abs_bijector = AbsoluteValue(event_ndims=event_ndims, validate_args=True) - with self.assertRaisesOpError("event_ndims was not 0"): - sess.run(abs_bijector.inverse_log_det_jacobian([1.]), - feed_dict={event_ndims: 1}) + self.assertAllClose((0., 0.), + sess.run(bijector.inverse_log_det_jacobian( + y, event_ndims=0))) def testNegativeYRaisesForInverseIfValidateArgs(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): sess.run(bijector.inverse(-1.)) def testNegativeYRaisesForILDJIfValidateArgs(self): with self.test_session() as sess: - bijector = AbsoluteValue(event_ndims=0, validate_args=True) + bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): - sess.run(bijector.inverse_log_det_jacobian(-1.)) + sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 405ddd292cacd8ace87d6caeebf3e8cfc347c22d..1e4ad724d00f751a55370ef9aa6dde0003a2098c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -38,9 +38,11 @@ class AffineLinearOperatorTest(test.TestCase): self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose(ildj, affine.inverse_log_det_jacobian( + y, event_ndims=2).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(), + affine.forward_log_det_jacobian(x, event_ndims=2).eval()) def testDiag(self): with self.test_session(): @@ -58,14 +60,16 @@ class AffineLinearOperatorTest(test.TestCase): self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, affine.inverse_log_det_jacobian(y, event_ndims=1).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=1).eval(), + affine.forward_log_det_jacobian(x, event_ndims=1).eval()) def testTriL(self): with self.test_session(): shift = np.array([-1, 0, 1], dtype=np.float32) - tril = np.array([[[1, 0, 0], + tril = np.array([[[3, 0, 0], [2, -1, 0], [3, 2, 1]], [[2, 0, 0], @@ -85,15 +89,17 @@ class AffineLinearOperatorTest(test.TestCase): # y = np.matmul(x, tril) + shift. y = np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift ildj = -np.sum(np.log(np.abs(np.diagonal( - tril, axis1=-2, axis2=-1))), - axis=-1) + tril, axis1=-2, axis2=-1)))) self.assertEqual(affine.name, "affine_linear_operator") self.assertAllClose(y, affine.forward(x).eval()) self.assertAllClose(x, affine.inverse(y).eval()) - self.assertAllClose(ildj, affine.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-affine.inverse_log_det_jacobian(y).eval(), - affine.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, affine.inverse_log_det_jacobian( + y, event_ndims=2).eval()) + self.assertAllClose( + -affine.inverse_log_det_jacobian(y, event_ndims=2).eval(), + affine.forward_log_det_jacobian(x, event_ndims=2).eval()) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d2533620bebeb0400b6d4a6346e8315c7e37c5c6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py @@ -0,0 +1,160 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Affine Scalar Tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import AffineScalar +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class AffineScalarBijectorTest(test.TestCase): + """Tests correctness of the Y = scale @ x + shift transformation.""" + + def testProperties(self): + with self.test_session(): + mu = -1. + # scale corresponds to 1. + bijector = AffineScalar(shift=mu) + self.assertEqual("affine_scalar", bijector.name) + + def testNoBatchScalar(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = -1. + # Corresponds to scale = 2 + bijector = AffineScalar(shift=mu, scale=2.) + x = [1., 2, 3] # Three scalar samples (no batches). + self.assertAllClose([1., 3, 5], run(bijector.forward, x)) + self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float64) + x = array_ops.placeholder(dtypes.float64, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = np.float64([1.]) + # One batch, scalar. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu) + x = np.float64([1.]) # One sample from one batches. + self.assertAllClose([2.], run(bijector.forward, x)) + self.assertAllClose([0.], run(bijector.inverse, x)) + self.assertAllClose( + 0., + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float64) + x = array_ops.placeholder(dtypes.float64, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + multiplier = np.float64([2.]) + # One batch, scalar. + # Corresponds to scale = 2, shift = 0. + bijector = AffineScalar(scale=multiplier) + x = np.float64([1.]) # One sample from one batches. + self.assertAllClose([2.], run(bijector.forward, x)) + self.assertAllClose([0.5], run(bijector.inverse, x)) + self.assertAllClose( + [np.log(0.5)], + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testTwoBatchScalarIdentityViaIdentity(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float32) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = [1., -1] + # Univariate, two batches. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu) + x = [1., 1] # One sample from each of two batches. + self.assertAllClose([2., 0], run(bijector.forward, x)) + self.assertAllClose([0., 2], run(bijector.inverse, x)) + self.assertAllClose( + 0., + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testTwoBatchScalarIdentityViaScale(self): + with self.test_session() as sess: + + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() + + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value).astype(np.float32) + x = array_ops.placeholder(dtypes.float32, name="x") + return sess.run(fun(x, **kwargs), feed_dict={x: x_value}) + + for run in (static_run, dynamic_run): + mu = [1., -1] + # Univariate, two batches. + # Corresponds to scale = 1. + bijector = AffineScalar(shift=mu, scale=[2., 1]) + x = [1., 1] # One sample from each of two batches. + self.assertAllClose([3., 0], run(bijector.forward, x)) + self.assertAllClose([0., 2], run(bijector.inverse, x)) + self.assertAllClose( + [-np.log(2), 0.], + run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) + + def testScalarCongruency(self): + with self.test_session(): + bijector = AffineScalar(shift=3.6, scale=0.42) + assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index c9158117f7a982e37047e8dd2b534a30040a87d9..9e14b9a53e6c63876478d876030c476c5d77dbbb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -25,7 +25,6 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -36,209 +35,26 @@ class AffineBijectorTest(test.TestCase): with self.test_session(): mu = -1. # scale corresponds to 1. - bijector = Affine(shift=mu, event_ndims=0) + bijector = Affine(shift=mu) self.assertEqual("affine", bijector.name) - def testNoBatchScalarViaIdentity(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2 - bijector = Affine( - shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 2, 3] # Three scalar samples (no batches). - self.assertAllClose([1., 3, 5], run(bijector.forward, x)) - self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testNoBatchScalarViaDiag(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2 - bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 2, 3] # Three scalar samples (no batches). - self.assertAllClose([1., 3, 5], run(bijector.forward, x)) - self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testWeirdSampleNoBatchScalarViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = -1. - # Corresponds to scale = 2. - bijector = Affine( - shift=mu, scale_identity_multiplier=2., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [[1., 2, 3], [4, 5, 6]] # Weird sample shape. - self.assertAllClose([[1., 3, 5], - [7, 9, 11]], - run(bijector.forward, x)) - self.assertAllClose([[1., 1.5, 2.], - [2.5, 3, 3.5]], - run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value).astype(np.float64) - x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = np.float64([1.]) - # One batch, scalar. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = np.float64([1.]) # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaIdentityIn64BitUserProvidesMultiplierOnly(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value).astype(np.float64) - x = array_ops.placeholder(dtypes.float64, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - multiplier = np.float64([2.]) - # One batch, scalar. - # Corresponds to scale = 2, shift = 0. - bijector = Affine(scale_identity_multiplier=multiplier, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = np.float64([1.]) # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.5], run(bijector.inverse, x)) - self.assertAllClose([np.log(0.5)], - run(bijector.inverse_log_det_jacobian, x)) - - def testOneBatchScalarViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1.] - # One batch, scalar. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1.] # One sample from one batches. - self.assertAllClose([2.], run(bijector.forward, x)) - self.assertAllClose([0.], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testTwoBatchScalarIdentityViaIdentity(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1., -1] - # Univariate, two batches. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 1] # One sample from each of two batches. - self.assertAllClose([2., 0], run(bijector.forward, x)) - self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - - def testTwoBatchScalarIdentityViaDiagMultiplier(self): - with self.test_session() as sess: - - def static_run(fun, x): - return fun(x).eval() - - def dynamic_run(fun, x_value): - x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) - - for run in (static_run, dynamic_run): - mu = [1., -1] - # Univariate, two batches. - # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" - x = [1., 1] # One sample from each of two batches. - self.assertAllClose([2., 0], run(bijector.forward, x)) - self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - def testNoBatchMultivariateIdentity(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] # Multivariate # Corresponds to scale = [[1., 0], [0, 1.]] bijector = Affine(shift=mu) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 1] # matmul(sigma, x) + shift # = [-1, -1] + [1, -1] @@ -251,33 +67,37 @@ class AffineBijectorTest(test.TestCase): x = [[1., 1], [-1., -1]] self.assertAllClose([[2., 0], [0., -2]], run(bijector.forward, x)) self.assertAllClose([[0., 2], [-2., 0]], run(bijector.inverse, x)) - self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + 0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateDiag(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [1., -1] # Multivariate # Corresponds to scale = [[2., 0], [0, 1.]] bijector = Affine(shift=mu, scale_diag=[2., 1]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 1] # matmul(sigma, x) + shift # = [-1, -1] + [1, -1] self.assertAllClose([3., 0], run(bijector.forward, x)) self.assertAllClose([0., 2], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + # Reset bijector. + bijector = Affine(shift=mu, scale_diag=[2., 1]) # x is a 2-batch of 2-vectors. # The first vector is [1, 1], the second is [-1, -1]. # Each undergoes matmul(sigma, x) + shift. @@ -289,120 +109,116 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([[0., 2], [-1., 0]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateFullDynamic(self): with self.test_session() as sess: x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") - event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims") x_value = np.array([[1., 1]], dtype=np.float32) mu_value = np.array([1., -1], dtype=np.float32) scale_diag_value = np.array([2., 2], dtype=np.float32) - event_ndims_value = np.array(1, dtype=np.int32) feed_dict = { x: x_value, mu: mu_value, scale_diag: scale_diag_value, - event_ndims: event_ndims_value } - bijector = Affine( - shift=mu, scale_diag=scale_diag, event_ndims=event_ndims) - self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict)) + bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[3., 1]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[0., 1]], sess.run(bijector.inverse(x), feed_dict)) self.assertAllClose( -np.log(4), - sess.run(bijector.inverse_log_det_jacobian(x), feed_dict)) + sess.run(bijector.inverse_log_det_jacobian(x, event_ndims=1), + feed_dict)) def testBatchMultivariateIdentity(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value, dtype=np.float32) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [[1., -1]] # Corresponds to 1 2x2 matrix, with twos on the diagonal. scale = 2. bijector = Affine(shift=mu, scale_identity_multiplier=scale) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(4), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(4), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateDiag(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): - x_value = np.array(x_value, dtype=np.float32) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + def dynamic_run(fun, x_value, **kwargs): + x_value = np.array(x_value) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = [[1., -1]] # Corresponds to 1 2x2 matrix, with twos on the diagonal. scale_diag = [[2., 2]] bijector = Affine(shift=mu, scale_diag=scale_diag) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[[1., 1]]] self.assertAllClose([[[3., 1]]], run(bijector.forward, x)) self.assertAllClose([[[0., 1]]], run(bijector.inverse, x)) - self.assertAllClose([-np.log(4)], - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + [-np.log(4)], + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateFullDynamic(self): with self.test_session() as sess: x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") - event_ndims = array_ops.placeholder(dtypes.int32, name="event_ndims") x_value = np.array([[[1., 1]]], dtype=np.float32) mu_value = np.array([[1., -1]], dtype=np.float32) scale_diag_value = np.array([[2., 2]], dtype=np.float32) - event_ndims_value = 1 feed_dict = { x: x_value, mu: mu_value, scale_diag: scale_diag_value, - event_ndims: event_ndims_value } - bijector = Affine( - shift=mu, scale_diag=scale_diag, event_ndims=event_ndims) - self.assertEqual(1, sess.run(bijector.event_ndims, feed_dict)) + bijector = Affine(shift=mu, scale_diag=scale_diag) self.assertAllClose([[[3., 1]]], sess.run(bijector.forward(x), feed_dict)) self.assertAllClose([[[0., 1]]], sess.run(bijector.inverse(x), feed_dict)) - self.assertAllClose([-np.log(4)], - sess.run( - bijector.inverse_log_det_jacobian(x), feed_dict)) + self.assertAllClose( + [-np.log(4)], + sess.run(bijector.inverse_log_det_jacobian( + x, event_ndims=1), feed_dict)) def testIdentityWithDiagUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -410,25 +226,25 @@ class AffineBijectorTest(test.TestCase): bijector = Affine( shift=mu, scale_identity_multiplier=1., - scale_diag=[1., 1., 1.], - event_ndims=1) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" + scale_diag=[1., 1., 1.]) x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.**3), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(2.**3), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -437,46 +253,48 @@ class AffineBijectorTest(test.TestCase): shift=mu, scale_identity_multiplier=1., scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 5]], run(bijector.forward, x)) self.assertAllClose([[1., 0.5]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(4.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(4.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. # scale = [[2., 0], [2, 3]] bijector = Affine( shift=mu, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[1., 7]], run(bijector.forward, x)) self.assertAllClose([[1., 1 / 3.]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(6.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(6.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityAndDiagWithTriL(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -486,23 +304,24 @@ class AffineBijectorTest(test.TestCase): scale_identity_multiplier=1.0, scale_diag=[1., 2.], scale_tril=[[1., 0], [2., 1]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [[1., 2]] # One multivariate sample. self.assertAllClose([[2., 9]], run(bijector.forward, x)) self.assertAllClose([[2 / 3., 5 / 12.]], run(bijector.inverse, x)) - self.assertAllClose(-np.log(12.), - run(bijector.inverse_log_det_jacobian, x)) + self.assertAllClose( + -np.log(12.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -514,7 +333,6 @@ class AffineBijectorTest(test.TestCase): scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = Affine(shift=mu, scale_diag=[10., 2, 3]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 3, 8], run(bijector.forward, x)) self.assertAllClose( @@ -523,22 +341,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 1.5, 4 / 3.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(60.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(60.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -550,7 +370,6 @@ class AffineBijectorTest(test.TestCase): scale_perturb_factor=[[2., 0], [0., 0], [0, 1]]) bijector_ref = Affine(shift=mu, scale_diag=[10., 3, 5]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 5, 14], run(bijector.forward, x)) self.assertAllClose( @@ -558,22 +377,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 1., 0.8], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(150.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(150.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdate(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -586,7 +407,6 @@ class AffineBijectorTest(test.TestCase): bijector_ref = Affine( shift=mu, scale_tril=[[10., 0, 0], [1, 3, 0], [2, 3, 5]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([9., 6, 22], run(bijector.forward, x)) self.assertAllClose( @@ -594,22 +414,24 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0.2, 14 / 15., 4 / 25.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(150.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(150.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdateNoDiagonal(self): with self.test_session() as sess: + placeholder = array_ops.placeholder(dtypes.float32, name="x") - def static_run(fun, x): - return fun(x).eval() + def static_run(fun, x, **kwargs): + return fun(x, **kwargs).eval() - def dynamic_run(fun, x_value): + def dynamic_run(fun, x_value, **kwargs): x_value = np.array(x_value) - x = array_ops.placeholder(dtypes.float32, name="x") - return sess.run(fun(x), feed_dict={x: x_value}) + return sess.run( + fun(placeholder, **kwargs), feed_dict={placeholder: x_value}) for run in (static_run, dynamic_run): mu = -1. @@ -622,7 +444,6 @@ class AffineBijectorTest(test.TestCase): bijector_ref = Affine( shift=mu, scale_tril=[[6., 0, 0], [1, 3, 0], [2, 3, 5]]) - self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Vector. self.assertAllClose([5., 6, 22], run(bijector.forward, x)) self.assertAllClose( @@ -630,11 +451,12 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([1 / 3., 8 / 9., 4 / 30.], run(bijector.inverse, x)) self.assertAllClose( run(bijector_ref.inverse, x), run(bijector.inverse, x)) - self.assertAllClose(-np.log(90.), - run(bijector.inverse_log_det_jacobian, x)) self.assertAllClose( - run(bijector.inverse_log_det_jacobian, x), - run(bijector_ref.inverse_log_det_jacobian, x)) + -np.log(90.), + run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) + self.assertAllClose( + run(bijector.inverse_log_det_jacobian, x, event_ndims=1), + run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateRaisesWhenSingular(self): with self.test_session(): @@ -647,38 +469,6 @@ class AffineBijectorTest(test.TestCase): with self.assertRaisesOpError("diagonal part must be non-zero"): bijector.forward([1., 1.]).eval() - def testEventNdimsLargerThanOneRaises(self): - with self.test_session(): - mu = [1., -1] - with self.assertRaisesRegexp( - ValueError, (r"event_ndims\(2\) was not 0 or 1")): - # Scale corresponds to 2x2 identity matrix. - bijector = Affine(shift=mu, event_ndims=2, validate_args=True) - bijector.forward([1., 1.]).eval() - - def testScaleZeroScalarRaises(self): - with self.test_session(): - mu = -1. - # Check Identity matrix with zero scaling. - bijector = Affine( - shift=mu, - scale_identity_multiplier=0., - event_ndims=0, - validate_args=True) - with self.assertRaisesOpError("identity_multiplier should be non-zero"): - bijector.forward(1.).eval() - - def testScaleDiagAndEventNdimsZeroRaises(self): - # Check Diag matrix with zero scaling. - with self.assertRaisesRegexp(ValueError, "only scale argument"): - Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True) - - def testScalarCongruency(self): - with self.test_session(): - bijector = Affine( - shift=3.6, scale_identity_multiplier=0.42, event_ndims=0) - assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) - def _makeScale(self, x, scale_identity_multiplier=None, @@ -747,14 +537,12 @@ class AffineBijectorTest(test.TestCase): scale_args = dict({"x": x}, **args) scale = self._makeScale(**scale_args) - bijector_args = dict({"event_ndims": 1}, **args) - # We haven't specified enough information for the scale. if scale is None: with self.assertRaisesRegexp(ValueError, ("must be specified.")): - bijector = Affine(shift=shift, **bijector_args) + bijector = Affine(shift=shift, **args) else: - bijector = Affine(shift=shift, **bijector_args) + bijector = Affine(shift=shift, **args) np_x = x # For the case a vector is passed in, we need to make the shape # match the matrix for matmul to work. @@ -771,6 +559,7 @@ class AffineBijectorTest(test.TestCase): backward = np.squeeze(backward, axis=-1) self.assertAllClose(backward, bijector.inverse(x).eval()) + scale *= np.ones(shape=x.shape[:-1], dtype=scale.dtype) ildj = -np.log(np.abs(np.linalg.det(scale))) # TODO(jvdillon): We need to make it so the scale_identity_multiplier # case does not deviate in expected shape. Fixing this will get rid of @@ -781,7 +570,8 @@ class AffineBijectorTest(test.TestCase): ildj = np.squeeze(ildj[0]) elif ildj.ndim < scale.ndim - 2: ildj = np.reshape(ildj, scale.shape[0:-2]) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(x).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian(x, event_ndims=1).eval()) def testLegalInputs(self): self._testLegalInputs( @@ -829,15 +619,5 @@ class AffineBijectorTest(test.TestCase): x=np.array( [1., 2], dtype=np.float32)) - def testScalarEventIdentityScale(self): - with self.test_session() as sess: - doubler = Affine( - scale_identity_multiplier=2., - event_ndims=0) - doubler2 = doubler.inverse_log_det_jacobian(2.) - doubler2_ildj_ = sess.run([doubler2]) - self.assertAllClose([-np.log(2.)], doubler2_ildj_) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c832fcaa686c92f83810e4f99ca3b23ae694b723 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py @@ -0,0 +1,237 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for BatchNorm Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import distributions +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import BatchNormalization +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test +from tensorflow.python.training import adam + + +class BatchNormTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + def _reduction_axes(self, input_shape, event_dims): + if isinstance(event_dims, int): + event_dims = [event_dims] + ndims = len(input_shape) + # Convert event_dims to non-negative indexing. + event_dims = list(event_dims) + for idx, x in enumerate(event_dims): + if x < 0: + event_dims[idx] = ndims + x + return tuple(i for i in range(ndims) if i not in event_dims) + + def testForwardInverse(self): + """Tests forward and backward passes with different event shapes. + + input_shape: Tuple of shapes for input tensor. + event_dims: Tuple of dimension indices that will be normalized. + training: Boolean of whether bijector runs in training or inference mode. + """ + params = [ + ((5*2, 4), [-1], False), + ((5, 2, 4), [-1], False), + ((5, 2, 4), [1, 2], False), + ((5, 2, 4), [0, 1], False), + ((5*2, 4), [-1], True), + ((5, 2, 4), [-1], True), + ((5, 2, 4), [1, 2], True), + ((5, 2, 4), [0, 1], True) + ] + for input_shape, event_dims, training in params: + x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape) + with self.test_session() as sess: + x = constant_op.constant(x_) + # When training, memorize the exact mean of the last + # minibatch that it normalized (instead of moving average assignment). + layer = normalization.BatchNormalization( + axis=event_dims, momentum=0., epsilon=0.) + batch_norm = BatchNormalization( + batchnorm_layer=layer, training=training) + # Minibatch statistics are saved only after norm_x has been computed. + norm_x = batch_norm.inverse(x) + with ops.control_dependencies(batch_norm.batchnorm.updates): + moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) + moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) + denorm_x = batch_norm.forward(array_ops.identity(norm_x)) + fldj = batch_norm.forward_log_det_jacobian( + x, event_ndims=len(event_dims)) + # Use identity to invalidate cache. + ildj = batch_norm.inverse_log_det_jacobian( + array_ops.identity(denorm_x), event_ndims=len(event_dims)) + variables.global_variables_initializer().run() + # Update variables. + norm_x_ = sess.run(norm_x) + [ + norm_x_, + moving_mean_, + moving_var_, + denorm_x_, + ildj_, + fldj_, + ] = sess.run([ + norm_x, + moving_mean, + moving_var, + denorm_x, + ildj, + fldj, + ]) + self.assertEqual("batch_normalization", batch_norm.name) + + reduction_axes = self._reduction_axes(input_shape, event_dims) + keepdims = len(event_dims) > 1 + + expected_batch_mean = np.mean( + x_, axis=reduction_axes, keepdims=keepdims) + expected_batch_var = np.var(x_, axis=reduction_axes, keepdims=keepdims) + + if training: + # When training=True, values become normalized across batch dim and + # original values are recovered after de-normalizing. + zeros = np.zeros_like(norm_x_) + self.assertAllClose(np.mean(zeros, axis=reduction_axes), + np.mean(norm_x_, axis=reduction_axes)) + + self.assertAllClose(expected_batch_mean, moving_mean_) + self.assertAllClose(expected_batch_var, moving_var_) + self.assertAllClose(x_, denorm_x_, atol=1e-5) + # Since moving statistics are set to batch statistics after + # normalization, ildj and -fldj should match. + self.assertAllClose(ildj_, -fldj_) + # ildj is computed with minibatch statistics. + expected_ildj = np.sum(np.log(1.) - .5 * np.log( + expected_batch_var + batch_norm.batchnorm.epsilon)) + self.assertAllClose(expected_ildj, ildj_) + else: + # When training=False, moving_mean, moving_var remain at their + # initialized values (0., 1.), resulting in no scale/shift (a small + # shift occurs if epsilon > 0.) + self.assertAllClose(x_, norm_x_) + self.assertAllClose(x_, denorm_x_, atol=1e-5) + # ildj is computed with saved statistics. + expected_ildj = np.sum( + np.log(1.) - .5 * np.log(1. + batch_norm.batchnorm.epsilon)) + self.assertAllClose(expected_ildj, ildj_) + + def testMaximumLikelihoodTraining(self): + # Test Maximum Likelihood training with default bijector. + with self.test_session() as sess: + base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) + batch_norm = BatchNormalization(training=True) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=base_dist, + bijector=batch_norm) + target_dist = distributions.MultivariateNormalDiag(loc=[1., 2.]) + target_samples = target_dist.sample(100) + dist_samples = dist.sample(3000) + loss = -math_ops.reduce_mean(dist.log_prob(target_samples)) + with ops.control_dependencies(batch_norm.batchnorm.updates): + train_op = adam.AdamOptimizer(1e-2).minimize(loss) + moving_mean = array_ops.identity(batch_norm.batchnorm.moving_mean) + moving_var = array_ops.identity(batch_norm.batchnorm.moving_variance) + variables.global_variables_initializer().run() + for _ in range(3000): + sess.run(train_op) + [ + dist_samples_, + moving_mean_, + moving_var_ + ] = sess.run([ + dist_samples, + moving_mean, + moving_var + ]) + self.assertAllClose([1., 2.], np.mean(dist_samples_, axis=0), atol=5e-2) + self.assertAllClose([1., 2.], moving_mean_, atol=5e-2) + self.assertAllClose([1., 1.], moving_var_, atol=5e-2) + + def testLogProb(self): + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) + base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=base_dist, + bijector=batch_norm, + validate_args=True) + samples = dist.sample(int(1e5)) + # No volume distortion since training=False, bijector is initialized + # to the identity transformation. + base_log_prob = base_dist.log_prob(samples) + dist_log_prob = dist.log_prob(samples) + variables.global_variables_initializer().run() + base_log_prob_, dist_log_prob_ = sess.run([base_log_prob, dist_log_prob]) + self.assertAllClose(base_log_prob_, dist_log_prob_) + + def testMutuallyConsistent(self): + # BatchNorm bijector is only mutually consistent when training=False. + dims = 4 + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=batch_norm, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=2., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + # BatchNorm bijector is only mutually consistent when training=False. + dims = 4 + with self.test_session() as sess: + layer = normalization.BatchNormalization(epsilon=0.) + batch_norm = Invert( + BatchNormalization(batchnorm_layer=layer, training=False)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=batch_norm, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=2., + center=0., + rtol=0.02) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index 20e754308449af3f0399101f4ea1bb47b3356424..ca20442c3940664feab7526110229872a6cdc41f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -20,21 +20,33 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test +class ShapeChanging(bijector.Bijector): + """Only used for op_ndims manipulation.""" + + def __init__(self, forward_min_event_ndims=0, inverse_min_event_ndims=3): + super(ShapeChanging, self).__init__( + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, + validate_args=False, name="shape_changer") + + class ChainBijectorTest(test.TestCase): """Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation.""" def testBijector(self): with self.test_session(): - chain = Chain((Exp(event_ndims=1), Softplus(event_ndims=1))) + chain = Chain((Exp(), Softplus())) self.assertEqual("chain_of_exp_of_softplus", chain.name) x = np.asarray([[[1., 2.], [2., 3.]]]) @@ -42,9 +54,10 @@ class ChainBijectorTest(test.TestCase): self.assertAllClose(np.log(x - 1.), chain.inverse(x).eval()) self.assertAllClose( -np.sum(np.log(x - 1.), axis=2), - chain.inverse_log_det_jacobian(x).eval()) + chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - np.sum(x, axis=2), chain.forward_log_det_jacobian(x).eval()) + np.sum(x, axis=2), + chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testBijectorIdentity(self): with self.test_session(): @@ -54,33 +67,126 @@ class ChainBijectorTest(test.TestCase): [2., 3.]]]) self.assertAllClose(x, chain.forward(x).eval()) self.assertAllClose(x, chain.inverse(x).eval()) - self.assertAllClose(0., chain.inverse_log_det_jacobian(x).eval()) - self.assertAllClose(0., chain.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + 0., chain.inverse_log_det_jacobian(x, event_ndims=1).eval()) + self.assertAllClose( + 0., chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): - bijector = Chain((Exp(), Softplus())) + chain = Chain((Exp(), Softplus())) assert_scalar_congruency( - bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + chain, lower_x=1e-3, upper_x=1.5, rtol=0.05) def testShapeGetters(self): with self.test_session(): - bijector = Chain([ - SoftmaxCentered( - event_ndims=1, validate_args=True), - SoftmaxCentered( - event_ndims=0, validate_args=True) + chain = Chain([ + SoftmaxCentered(validate_args=True), + SoftmaxCentered(validate_args=True), ]) - x = tensor_shape.TensorShape([]) + x = tensor_shape.TensorShape([1]) y = tensor_shape.TensorShape([2 + 1]) - self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y, chain.forward_event_shape(x)) self.assertAllEqual( y.as_list(), - bijector.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) + chain.forward_event_shape_tensor(x.as_list()).eval()) + self.assertAllEqual(x, chain.inverse_event_shape(y)) self.assertAllEqual( x.as_list(), - bijector.inverse_event_shape_tensor(y.as_list()).eval()) + chain.inverse_event_shape_tensor(y.as_list()).eval()) + + def testMinEventNdimsChain(self): + chain = Chain([Exp(), Exp(), Exp()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Affine(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Exp(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Exp()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), Exp(), Softplus(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingAddDims(self): + chain = Chain([ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(3, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(), Affine()]) + self.assertEqual(1, chain.forward_min_event_ndims) + self.assertEqual(4, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(3, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(), ShapeChanging()]) + self.assertEqual(0, chain.forward_min_event_ndims) + self.assertEqual(6, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingRemoveDims(self): + chain = Chain([ShapeChanging(3, 0)]) + self.assertEqual(3, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(3, 0), Affine()]) + self.assertEqual(3, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + chain = Chain([Affine(), ShapeChanging(3, 0)]) + self.assertEqual(4, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + chain = Chain([ShapeChanging(3, 0), ShapeChanging(3, 0)]) + self.assertEqual(6, chain.forward_min_event_ndims) + self.assertEqual(0, chain.inverse_min_event_ndims) + + def testMinEventNdimsShapeChangingAddRemoveDims(self): + chain = Chain([ + ShapeChanging(2, 1), + ShapeChanging(3, 0), + ShapeChanging(1, 2)]) + self.assertEqual(4, chain.forward_min_event_ndims) + self.assertEqual(1, chain.inverse_min_event_ndims) + + def testChainExpAffine(self): + scale_diag = np.array([1., 2., 3.], dtype=np.float32) + chain = Chain([Exp(), Affine(scale_diag=scale_diag)]) + x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] + y = [1., 4., 27.] + self.assertAllClose(y, self.evaluate(chain.forward(x))) + self.assertAllClose(x, self.evaluate(chain.inverse(y))) + self.assertAllClose( + np.log(6, dtype=np.float32) + np.sum(scale_diag * x), + self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) + + self.assertAllClose( + -np.log(6, dtype=np.float32) - np.sum(scale_diag * x), + self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + + def testChainAffineExp(self): + scale_diag = np.array([1., 2., 3.], dtype=np.float32) + chain = Chain([Affine(scale_diag=scale_diag), Exp()]) + x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] + y = [1., 4., 9.] + self.assertAllClose(y, self.evaluate(chain.forward(x))) + self.assertAllClose(x, self.evaluate(chain.inverse(y))) + self.assertAllClose( + np.log(6, dtype=np.float32) + np.sum(x), + self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1))) + + self.assertAllClose( + -np.log(6, dtype=np.float32) - np.sum(x), + self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index 0ff35304283fce9ce3f9e5d31b1258394e384d7b..e281e81bdf0698c1f7b2f60fb27783dd1351773f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -18,70 +18,114 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import bijectors -from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import gamma as gamma_lib -from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib -from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test -class InvertBijectorTest(test.TestCase): - """Tests the correctness of the Y = Invert(bij) transformation.""" +class CholeskyOuterProductBijectorTest(test.TestCase): + """Tests the correctness of the Y = X @ X.T transformation.""" - def testBijector(self): + def testBijectorMatrix(self): with self.test_session(): - for fwd in [ - bijectors.Identity(), - bijectors.Exp(event_ndims=1), - bijectors.Affine( - shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1), - bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), - ]: - rev = bijectors.Invert(fwd) - self.assertEqual("_".join(["invert", fwd.name]), rev.name) - x = [[[1., 2.], - [2., 3.]]] - self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval()) - self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval()) - self.assertAllClose( - fwd.forward_log_det_jacobian(x).eval(), - rev.inverse_log_det_jacobian(x).eval()) - self.assertAllClose( - fwd.inverse_log_det_jacobian(x).eval(), - rev.forward_log_det_jacobian(x).eval()) + bijector = bijectors.CholeskyOuterProduct(validate_args=True) + self.assertEqual("cholesky_outer_product", bijector.name) + x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]] + y = np.matmul(x, np.transpose(x, axes=(0, 2, 1))) + # Fairly easy to compute differentials since we have 2x2. + dx_dy = [[[2. * 1, 0, 0], + [2, 1, 0], + [0, 2 * 2, 2 * 1]], + [[2 * np.sqrt(2.), 0, 0], + [np.sqrt(8.), np.sqrt(2.), 0], + [0, 2 * np.sqrt(8.), 2 * 1]]] + ildj = -np.sum( + np.log(np.asarray(dx_dy).diagonal( + offset=0, axis1=1, axis2=2)), + axis=1) + self.assertAllEqual((2, 2, 2), bijector.forward(x).get_shape()) + self.assertAllEqual((2, 2, 2), bijector.inverse(y).get_shape()) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=2).eval(), atol=0., rtol=1e-7) + self.assertAllClose( + -bijector.inverse_log_det_jacobian( + y, event_ndims=2).eval(), + bijector.forward_log_det_jacobian( + x, event_ndims=2).eval(), + atol=0., + rtol=1e-7) - def testScalarCongruency(self): - with self.test_session(): - bijector = bijectors.Invert(bijectors.Exp()) - assert_scalar_congruency( - bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + def testNoBatchStatic(self): + x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) + y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) + with self.test_session() as sess: + y_actual = bijectors.CholeskyOuterProduct().forward(x=x) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) + self.assertAllEqual([2, 2], y_actual.get_shape()) + self.assertAllEqual([2, 2], x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) - def testShapeGetters(self): - with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) - x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) - self.assertAllEqual(y, bijector.forward_event_shape(x)) - self.assertAllEqual( - y.as_list(), - bijector.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) - self.assertAllEqual( - x.as_list(), - bijector.inverse_event_shape_tensor(y.as_list()).eval()) + def testNoBatchDeferred(self): + x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) + y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtypes.float32) + y_pl = array_ops.placeholder(dtypes.float32) + y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual], + feed_dict={x_pl: x, y_pl: y}) + self.assertEqual(None, y_actual.get_shape()) + self.assertEqual(None, x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) - def testDocstringExample(self): - with self.test_session(): - exp_gamma_distribution = ( - transformed_distribution_lib.TransformedDistribution( - distribution=gamma_lib.Gamma(concentration=1., rate=2.), - bijector=bijectors.Invert(bijectors.Exp()))) - self.assertAllEqual( - [], array_ops.shape(exp_gamma_distribution.sample()).eval()) + def testBatchStatic(self): + x = np.array([[[1., 0], + [2, 1]], + [[3., 0], + [1, 2]]]) # np.linalg.cholesky(y) + y = np.array([[[1., 2], + [2, 5]], + [[9., 3], + [3, 5]]]) # np.matmul(x, x.T) + with self.test_session() as sess: + y_actual = bijectors.CholeskyOuterProduct().forward(x=x) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) + self.assertEqual([2, 2, 2], y_actual.get_shape()) + self.assertEqual([2, 2, 2], x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) + + def testBatchDeferred(self): + x = np.array([[[1., 0], + [2, 1]], + [[3., 0], + [1, 2]]]) # np.linalg.cholesky(y) + y = np.array([[[1., 2], + [2, 5]], + [[9., 3], + [3, 5]]]) # np.matmul(x, x.T) + with self.test_session() as sess: + x_pl = array_ops.placeholder(dtypes.float32) + y_pl = array_ops.placeholder(dtypes.float32) + y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) + x_actual = bijectors.CholeskyOuterProduct().inverse(y=y_pl) + [y_actual_, x_actual_] = sess.run([y_actual, x_actual], + feed_dict={x_pl: x, y_pl: y}) + self.assertEqual(None, y_actual.get_shape()) + self.assertEqual(None, x_actual.get_shape()) + self.assertAllClose(y, y_actual_) + self.assertAllClose(x, x_actual_) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 26e0d2a539c78540603281ae0f361987a7bf8d90..8b279ebcd908b6f375b35594ac5f3db9228a1e31 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -27,7 +27,7 @@ class _TestBijector(ConditionalBijector): def __init__(self): super(_TestBijector, self).__init__( - event_ndims=0, + forward_min_event_ndims=0, graph_parents=[], is_constant_jacobian=True, validate_args=False, @@ -51,11 +51,15 @@ class ConditionalBijectorTest(test.TestCase): def testConditionalBijector(self): b = _TestBijector() - for name in ["forward", "inverse", "inverse_log_det_jacobian", - "forward_log_det_jacobian"]: + for name in ["forward", "inverse"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1.0, arg1="b1", arg2="b2") + method(1., arg1="b1", arg2="b2") + + for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: + method = getattr(b, name) + with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): + method(1., event_ndims=0., arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py index 9970c0b4d86afda188d9401ebaf3c98d3fffbfdf..7be939cd274e6f0e33c9b01c82494755db2caa73 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py @@ -31,17 +31,21 @@ class ExpBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - bijector = Exp(event_ndims=1) + bijector = Exp() self.assertEqual("exp", bijector.name) x = [[[1.], [2.]]] y = np.exp(x) self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - -np.sum(np.log(y), axis=-1), - bijector.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-bijector.inverse_log_det_jacobian(np.exp(x)).eval(), - bijector.forward_log_det_jacobian(x).eval()) + -np.squeeze(np.log(y), axis=-1), + bijector.inverse_log_det_jacobian( + y, event_ndims=1).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian( + np.exp(x), event_ndims=1).eval(), + bijector.forward_log_det_jacobian( + x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): @@ -51,10 +55,10 @@ class ExpBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): - bijector = Exp(event_ndims=0) + bijector = Exp() x = np.linspace(-10, 10, num=10).astype(np.float32) y = np.logspace(-10, 10, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y) + assert_bijective_and_finite(bijector, x, y, event_ndims=0) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py index 9a905980c7581a86bbcda8c6c726da57c09fe4f8..54e54c3296a89a4fe29a3cce971760502b65e784 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -34,7 +34,7 @@ class GumbelBijectorTest(test.TestCase): with self.test_session(): loc = 0.3 scale = 5. - bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True) + bijector = Gumbel(loc=loc, scale=scale, validate_args=True) self.assertEqual("gumbel", bijector.name) x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32) # Gumbel distribution @@ -43,13 +43,11 @@ class GumbelBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - np.squeeze(gumbel_dist.logpdf(x), axis=2), - bijector.forward_log_det_jacobian(x).eval()) + np.squeeze(gumbel_dist.logpdf(x), axis=-1), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -60,10 +58,10 @@ class GumbelBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): - bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True) + bijector = Gumbel(loc=0., scale=3.0, validate_args=True) x = np.linspace(-10., 10., num=10).astype(np.float32) y = np.linspace(0.01, 0.99, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py index 739fa6d439a8bce993ab1b4601489d9bbcd69bee..7d3bd758cd2db307f95d2d934923ea2133dc1217 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py @@ -33,15 +33,13 @@ class InlineBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): - exp = Exp(event_ndims=1) + exp = Exp() inline = Inline( forward_fn=math_ops.exp, inverse_fn=math_ops.log, - inverse_log_det_jacobian_fn=( - lambda y: -math_ops.reduce_sum( # pylint: disable=g-long-lambda - math_ops.log(y), reduction_indices=-1)), - forward_log_det_jacobian_fn=( - lambda x: math_ops.reduce_sum(x, reduction_indices=-1)), + inverse_log_det_jacobian_fn=lambda y: -math_ops.log(y), + forward_log_det_jacobian_fn=lambda x: x, + forward_min_event_ndims=0, name="exp") self.assertEqual(exp.name, inline.name) @@ -51,9 +49,10 @@ class InlineBijectorTest(test.TestCase): self.assertAllClose(x, inline.inverse(y).eval()) self.assertAllClose( -np.sum(np.log(y), axis=-1), - inline.inverse_log_det_jacobian(y).eval()) - self.assertAllClose(-inline.inverse_log_det_jacobian(y).eval(), - inline.forward_log_det_jacobian(x).eval()) + inline.inverse_log_det_jacobian(y, event_ndims=1).eval()) + self.assertAllClose( + -inline.inverse_log_det_jacobian(y, event_ndims=1).eval(), + inline.forward_log_det_jacobian(x, event_ndims=1).eval()) def testShapeGetters(self): with self.test_session(): @@ -62,6 +61,7 @@ class InlineBijectorTest(test.TestCase): forward_event_shape_fn=lambda x: x.as_list() + [1], inverse_event_shape_tensor_fn=lambda x: x[:-1], inverse_event_shape_fn=lambda x: x[:-1], + forward_min_event_ndims=0, name="shape_only") x = tensor_shape.TensorShape([1, 2, 3]) y = tensor_shape.TensorShape([1, 2, 3, 1]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index 0ff35304283fce9ce3f9e5d31b1258394e384d7b..8b14c8327f08902044f50483f9f8dfe67b58cd70 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -34,12 +34,10 @@ class InvertBijectorTest(test.TestCase): with self.test_session(): for fwd in [ bijectors.Identity(), - bijectors.Exp(event_ndims=1), - bijectors.Affine( - shift=[0., 1.], scale_diag=[2., 3.], event_ndims=1), - bijectors.Softplus(event_ndims=1), - bijectors.SoftmaxCentered(event_ndims=1), - bijectors.SigmoidCentered(), + bijectors.Exp(), + bijectors.Affine(shift=[0., 1.], scale_diag=[2., 3.]), + bijectors.Softplus(), + bijectors.SoftmaxCentered(), ]: rev = bijectors.Invert(fwd) self.assertEqual("_".join(["invert", fwd.name]), rev.name) @@ -48,11 +46,11 @@ class InvertBijectorTest(test.TestCase): self.assertAllClose(fwd.inverse(x).eval(), rev.forward(x).eval()) self.assertAllClose(fwd.forward(x).eval(), rev.inverse(x).eval()) self.assertAllClose( - fwd.forward_log_det_jacobian(x).eval(), - rev.inverse_log_det_jacobian(x).eval()) + fwd.forward_log_det_jacobian(x, event_ndims=1).eval(), + rev.inverse_log_det_jacobian(x, event_ndims=1).eval()) self.assertAllClose( - fwd.inverse_log_det_jacobian(x).eval(), - rev.forward_log_det_jacobian(x).eval()) + fwd.inverse_log_det_jacobian(x, event_ndims=1).eval(), + rev.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): @@ -62,9 +60,9 @@ class InvertBijectorTest(test.TestCase): def testShapeGetters(self): with self.test_session(): - bijector = bijectors.Invert(bijectors.SigmoidCentered(validate_args=True)) + bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True)) x = tensor_shape.TensorShape([2]) - y = tensor_shape.TensorShape([]) + y = tensor_shape.TensorShape([1]) self.assertAllEqual(y, bijector.forward_event_shape(x)) self.assertAllEqual( y.as_list(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a8089881f684db9f8876d6dd738e52bf2f1f7606 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Kumaraswamy Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class KumaraswamyBijectorTest(test.TestCase): + """Tests correctness of the Kumaraswamy bijector.""" + + def testBijector(self): + with self.test_session(): + a = 2. + b = 0.3 + bijector = Kumaraswamy( + concentration1=a, concentration0=b, validate_args=True) + self.assertEqual("kumaraswamy", bijector.name) + x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32) + # Kumaraswamy cdf. This is the same as inverse(x). + y = 1. - (1. - x ** a) ** b + self.assertAllClose(y, bijector.inverse(x).eval()) + self.assertAllClose(x, bijector.forward(y).eval()) + kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) + + (b - 1) * np.log1p(-x ** a)) + + self.assertAllClose( + np.squeeze(kumaraswamy_log_pdf, axis=-1), + bijector.inverse_log_det_jacobian(x, event_ndims=1).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(x, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(y, event_ndims=1).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Kumaraswamy(concentration1=0.5, concentration0=1.1), + lower_x=0., upper_x=1., n=int(10e3), rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + concentration1 = 1.2 + concentration0 = 2. + bijector = Kumaraswamy( + concentration1=concentration1, + concentration0=concentration0, validate_args=True) + # Omitting the endpoints 0 and 1, since idlj will be infinity at these + # endpoints. + y = np.linspace(.01, 0.99, num=10).astype(np.float32) + x = 1 - (1 - y ** concentration1) ** concentration0 + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index dcfb0eb05185d36d96947905c2eb91b2201aece1..5ba5a2083bf11791d7d58146dc2e6283b524d241 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -79,9 +79,10 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, forward_x = ma.forward(x) # Use identity to invalidate cache. inverse_y = ma.inverse(array_ops.identity(forward_x)) - fldj = ma.forward_log_det_jacobian(x) + fldj = ma.forward_log_det_jacobian(x, event_ndims=1) # Use identity to invalidate cache. - ildj = ma.inverse_log_det_jacobian(array_ops.identity(forward_x)) + ildj = ma.inverse_log_det_jacobian( + array_ops.identity(forward_x), event_ndims=1) variables.global_variables_initializer().run() [ forward_x_, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py index 54590de373441c32cc3214cb04d45cfc2d1807ed..7eef4ab599951bbb624652f13a0091363b36b93d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -53,8 +53,8 @@ class PermuteBijectorTest(test.TestCase): bijector.permutation, bijector.inverse(expected_y), bijector.forward(expected_x), - bijector.forward_log_det_jacobian(expected_x), - bijector.inverse_log_det_jacobian(expected_y), + bijector.forward_log_det_jacobian(expected_x, event_ndims=1), + bijector.inverse_log_det_jacobian(expected_y, event_ndims=1), ], feed_dict={permutation_ph: expected_permutation}) self.assertEqual("permute", bijector.name) self.assertAllEqual(expected_permutation, permutation_) @@ -78,10 +78,9 @@ class PermuteBijectorTest(test.TestCase): x = np.random.randn(4, 2, 3) y = x[..., permutation] with self.test_session(): - bijector = Permute( - permutation=permutation, - validate_args=True) - assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + bijector = Permute(permutation=permutation, validate_args=True) + assert_bijective_and_finite( + bijector, x, y, event_ndims=1, rtol=1e-6, atol=0) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py index de1659aa9f4d0f7d19ec2e8185715573b78eaf2b..85d22830132816cd6c77cd0b07870f3a22ae9798 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py @@ -32,8 +32,7 @@ class PowerTransformBijectorTest(test.TestCase): def testBijector(self): with self.test_session(): c = 0.2 - bijector = PowerTransform( - power=c, event_ndims=1, validate_args=True) + bijector = PowerTransform(power=c, validate_args=True) self.assertEqual("power_transform", bijector.name) x = np.array([[[-1.], [2.], [-5. + 1e-4]]]) y = (1. + x * c)**(1. / c) @@ -41,27 +40,25 @@ class PowerTransformBijectorTest(test.TestCase): self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( (c - 1.) * np.sum(np.log(y), axis=-1), - bijector.inverse_log_det_jacobian(y).eval()) + bijector.inverse_log_det_jacobian(y, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) def testScalarCongruency(self): with self.test_session(): - bijector = PowerTransform( - power=0.2, validate_args=True) + bijector = PowerTransform(power=0.2, validate_args=True) assert_scalar_congruency( bijector, lower_x=-2., upper_x=1.5, rtol=0.05) def testBijectiveAndFinite(self): with self.test_session(): - bijector = PowerTransform( - power=0.2, event_ndims=0, validate_args=True) + bijector = PowerTransform(power=0.2, validate_args=True) x = np.linspace(-4.999, 10, num=10).astype(np.float32) y = np.logspace(0.001, 10, num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py index 46fe7797419a9906ecdad60dd0dfe1e9d7c743ed..2d52895fbe0967cdd2260d6d298a291286858d09 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py @@ -52,24 +52,28 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): forward_x = nvp.forward(x) # Use identity to invalidate cache. inverse_y = nvp.inverse(array_ops.identity(forward_x)) - fldj = nvp.forward_log_det_jacobian(x) + forward_inverse_y = nvp.forward(inverse_y) + fldj = nvp.forward_log_det_jacobian(x, event_ndims=1) # Use identity to invalidate cache. - ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x)) + ildj = nvp.inverse_log_det_jacobian( + array_ops.identity(forward_x), event_ndims=1) variables.global_variables_initializer().run() [ forward_x_, inverse_y_, + forward_inverse_y_, ildj_, fldj_, ] = sess.run([ forward_x, inverse_y, + forward_inverse_y, ildj, fldj, ]) self.assertEqual("real_nvp", nvp.name) - self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) - self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(forward_x_, forward_inverse_y_, rtol=1e-1, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-1, atol=0.) self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) def testMutuallyConsistent(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index e216d88cb190dc16fc0056186f80817d6f2d7c67..46f2c63f9b0f78b25bb1948e6ea55ab20c5cfa6e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -65,8 +65,8 @@ class _ReshapeBijectorTest(object): ildj_) = sess.run(( bijector.inverse(expected_y), bijector.forward(expected_x), - bijector.forward_log_det_jacobian(expected_x), - bijector.inverse_log_det_jacobian(expected_y), + bijector.forward_log_det_jacobian(expected_x, event_ndims=2), + bijector.inverse_log_det_jacobian(expected_y, event_ndims=2), ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) @@ -301,7 +301,8 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): event_shape_in=[2, 3], event_shape_out=[1, 2, 3], validate_args=True) - assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + assert_bijective_and_finite( + bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): if ops._USE_C_API: diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py deleted file mode 100644 index 4ff3f334ccb59f1c117b3d35032d9e799cfd79bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_centered_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import SigmoidCentered -from tensorflow.python.platform import test - - -class SigmoidCenteredBijectorTest(test.TestCase): - """Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation.""" - - def testBijector(self): - with self.test_session(): - sigmoid = SigmoidCentered() - self.assertEqual("sigmoid_centered", sigmoid.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] - self.assertAllClose(y, sigmoid.forward(x).eval()) - self.assertAllClose(x, sigmoid.inverse(y).eval()) - self.assertAllClose( - -np.sum(np.log(y), axis=2), - sigmoid.inverse_log_det_jacobian(y).eval(), - atol=0., - rtol=1e-7) - self.assertAllClose( - -sigmoid.inverse_log_det_jacobian(y).eval(), - sigmoid.forward_log_det_jacobian(x).eval(), - atol=0., - rtol=1e-7) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index e4f9d72785c301284812a48c0a67614ca439ffae..cea4a62c22af5d98d38ee881b29c773e6a27a4b4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -36,12 +36,13 @@ class SigmoidBijectorTest(test.TestCase): x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32) y = special.expit(x) ildj = -np.log(y) - np.log1p(-y) - self.assertAllClose(y, Sigmoid().forward(x).eval(), atol=0., rtol=1e-2) - self.assertAllClose(x, Sigmoid().inverse(y).eval(), atol=0., rtol=1e-4) - self.assertAllClose(ildj, Sigmoid().inverse_log_det_jacobian(y).eval(), - atol=0., rtol=1e-6) - self.assertAllClose(-ildj, Sigmoid().forward_log_det_jacobian(x).eval(), - atol=0., rtol=1e-4) + bijector = Sigmoid() + self.assertAllClose(y, bijector.forward(x).eval(), atol=0., rtol=1e-2) + self.assertAllClose(x, bijector.inverse(y).eval(), atol=0., rtol=1e-4) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval(), atol=0., rtol=1e-6) + self.assertAllClose(-ildj, bijector.forward_log_det_jacobian( + x, event_ndims=0).eval(), atol=0., rtol=1e-4) def testScalarCongruency(self): with self.test_session(): @@ -52,7 +53,8 @@ class SigmoidBijectorTest(test.TestCase): x = np.linspace(-7., 7., 100).astype(np.float32) eps = 1e-3 y = np.linspace(eps, 1. - eps, 100).astype(np.float32) - assert_bijective_and_finite(Sigmoid(), x, y, atol=0., rtol=1e-4) + assert_bijective_and_finite( + Sigmoid(), x, y, event_ndims=0, atol=0., rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 172c180a44229089f06f250a872bc47a89991cf0..45760a29ee42835da69ef63803ccec7ce82a5a8f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -39,7 +39,6 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector = SinhArcsinh( skewness=skewness, tailweight=tailweight, - event_ndims=1, validate_args=True) self.assertEqual("SinhArcsinh", bijector.name) x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32) @@ -50,10 +49,11 @@ class SinhArcsinhBijectorTest(test.TestCase): np.sum( np.log(np.cosh(np.arcsinh(y) / tailweight - skewness)) - np.log(tailweight) - np.log(np.sqrt(y**2 + 1)), - axis=-1), bijector.inverse_log_det_jacobian(y).eval()) + axis=-1), + bijector.inverse_log_det_jacobian(y, event_ndims=1).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=1).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=1).eval(), rtol=1e-4, atol=0.) @@ -106,14 +106,15 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True) x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace( -2, 10, 1000))).astype(np.float32) - assert_bijective_and_finite(bijector, x, x, rtol=1e-3) + assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectiveAndFiniteSkewness1Tailweight3(self): with self.test_session(): bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True) x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace( -2, 5, 1000))).astype(np.float32) - assert_bijective_and_finite(bijector, x, x, rtol=1e-3) + assert_bijective_and_finite( + bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectorEndpoints(self): with self.test_session(): @@ -124,7 +125,8 @@ class SinhArcsinhBijectorTest(test.TestCase): [np.finfo(dtype).min, np.finfo(dtype).max], dtype=dtype) # Note that the above bijector is the identity bijector. Hence, the # log_det_jacobian will be 0. Because of this we use atol. - assert_bijective_and_finite(bijector, bounds, bounds, atol=2e-6) + assert_bijective_and_finite( + bijector, bounds, bounds, event_ndims=0, atol=2e-6) def testBijectorOverRange(self): with self.test_session(): @@ -156,12 +158,12 @@ class SinhArcsinhBijectorTest(test.TestCase): np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( y_float128**2 + 1)) - np.log(tailweight), - bijector.inverse_log_det_jacobian(y).eval(), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), rtol=1e-4, atol=0.) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), rtol=1e-4, atol=0.) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 62e3869db090e9c9327bc552d10234ff76ba28fd..0f0a2fa531a0585a709df4c2c3e2631e5c275986 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test @@ -32,70 +34,68 @@ rng = np.random.RandomState(42) class SoftmaxCenteredBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation.""" - def testBijectorScalar(self): + def testBijectorVector(self): with self.test_session(): - softmax = SoftmaxCentered() # scalar by default + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], - [4., 8, 12]]) - y = [[[2. / 3, 1. / 3], - [3. / 4, 1. / 4], - [4. / 5, 1. / 5]], - [[4. / 5, 1. / 5], - [8. / 9, 1. / 9], - [12. / 13, 1. / 13]]] + x = np.log([[2., 3, 4], [4., 8, 12]]) + y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] self.assertAllClose(y, softmax.forward(x).eval()) self.assertAllClose(x, softmax.inverse(y).eval()) self.assertAllClose( - -np.sum(np.log(y), axis=2), - softmax.inverse_log_det_jacobian(y).eval(), + -np.sum(np.log(y), axis=1), + softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), + -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval(), + softmax.forward_log_det_jacobian(x, event_ndims=1).eval(), atol=0., rtol=1e-7) - def testBijectorVector(self): + def testBijectorUnknownShape(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) - x = np.log([[2., 3, 4], [4., 8, 12]]) - y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] - self.assertAllClose(y, softmax.forward(x).eval()) - self.assertAllClose(x, softmax.inverse(y).eval()) + x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_x = np.log([[2., 3, 4], [4., 8, 12]]) + y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_y = [[0.2, 0.3, 0.4, 0.1], [0.16, 0.32, 0.48, 0.04]] + self.assertAllClose(real_y, softmax.forward(x).eval( + feed_dict={x: real_x})) + self.assertAllClose(real_x, softmax.inverse(y).eval( + feed_dict={y: real_y})) self.assertAllClose( - -np.sum(np.log(y), axis=1), - softmax.inverse_log_det_jacobian(y).eval(), + -np.sum(np.log(real_y), axis=1), + softmax.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), atol=0., rtol=1e-7) self.assertAllClose( - -softmax.inverse_log_det_jacobian(y).eval(), - softmax.forward_log_det_jacobian(x).eval(), + -softmax.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), + softmax.forward_log_det_jacobian(x, event_ndims=1).eval( + feed_dict={x: real_x}), atol=0., rtol=1e-7) def testShapeGetters(self): with self.test_session(): - for x, y, b in ((tensor_shape.TensorShape([]), - tensor_shape.TensorShape([2]), - SoftmaxCentered( - event_ndims=0, validate_args=True)), - (tensor_shape.TensorShape([4]), - tensor_shape.TensorShape([5]), - SoftmaxCentered( - event_ndims=1, validate_args=True))): - self.assertAllEqual(y, b.forward_event_shape(x)) - self.assertAllEqual(y.as_list(), - b.forward_event_shape_tensor(x.as_list()).eval()) - self.assertAllEqual(x, b.inverse_event_shape(y)) - self.assertAllEqual(x.as_list(), - b.inverse_event_shape_tensor(y.as_list()).eval()) + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([5]) + bijector = SoftmaxCentered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + bijector.forward_event_shape_tensor( + x.as_list()).eval()) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + bijector.inverse_event_shape_tensor( + y.as_list()).eval()) def testBijectiveAndFinite(self): with self.test_session(): - softmax = SoftmaxCentered(event_ndims=1) + softmax = SoftmaxCentered() x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32) # Make y values on the simplex with a wide range. y_0 = np.ones(5).astype(np.float32) @@ -104,7 +104,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): y = np.array([y_0, y_1, y_2]) y /= y.sum(axis=0) y = y.T # y.shape = [5, 3] - assert_bijective_and_finite(softmax, x, y) + assert_bijective_and_finite(softmax, x, y, event_ndims=1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index d9af9aec50d3d69bb10f69f2ffd6ca3a24c316f8..3d8a0a32bba3539f732140e8eb7ebeb532d73ff5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -43,13 +43,13 @@ class SoftplusBijectorTest(test.TestCase): def testHingeSoftnessZeroRaises(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=0., validate_args=True) + bijector = Softplus(hinge_softness=0., validate_args=True) with self.assertRaisesOpError("must be non-zero"): bijector.forward([1., 1.]).eval() def testBijectorForwardInverseEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) y = self._softplus(x) @@ -59,7 +59,7 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.5) + bijector = Softplus(hinge_softness=1.5) x = 2 * rng.randn(2, 10) y = 1.5 * self._softplus(x / 1.5) @@ -68,16 +68,17 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorLogDetJacobianEventDimsZero(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() y = 2 * rng.rand(2, 10) # No reduction needed if event_dims = 0. ildj = self._softplus_ildj_before_reduction(y) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval()) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval()) def testBijectorForwardInverseEventDimsOne(self): with self.test_session(): - bijector = Softplus(event_ndims=1) + bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) y = self._softplus(x) @@ -87,58 +88,59 @@ class SoftplusBijectorTest(test.TestCase): def testBijectorLogDetJacobianEventDimsOne(self): with self.test_session(): - bijector = Softplus(event_ndims=1) + bijector = Softplus() y = 2 * rng.rand(2, 10) ildj_before = self._softplus_ildj_before_reduction(y) ildj = np.sum(ildj_before, axis=1) - self.assertAllClose(ildj, bijector.inverse_log_det_jacobian(y).eval()) + self.assertAllClose(ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=1).eval()) def testScalarCongruency(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithPositiveHingeSoftness(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.3) + bijector = Softplus(hinge_softness=1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithNegativeHingeSoftness(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=-1.3) + bijector = Softplus(hinge_softness=-1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testBijectiveAndFinite32bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=1.23) + bijector = Softplus(hinge_softness=1.23) x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0, hinge_softness=-0.7) + bijector = Softplus(hinge_softness=-0.7) x = np.linspace(-20., 20., 100).astype(np.float32) y = -np.logspace(-10, 10, 100).astype(np.float32) assert_bijective_and_finite( - bijector, x, y, rtol=1e-2, atol=1e-2) + bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFinite16bit(self): with self.test_session(): - bijector = Softplus(event_ndims=0) + bijector = Softplus() # softplus(-20) is zero, so we can't use such a large range as in 32bit. x = np.linspace(-10., 20., 100).astype(np.float16) # Note that float16 is only in the open set (0, inf) for a smaller @@ -146,7 +148,7 @@ class SoftplusBijectorTest(test.TestCase): # for the test. y = np.logspace(-6, 3, 100).astype(np.float16) assert_bijective_and_finite( - bijector, x, y, rtol=1e-1, atol=1e-3) + bijector, x, y, event_ndims=0, rtol=1e-1, atol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac06fce55b448a5f3da7ccb7f8766b5b1404ad7 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.softsign import Softsign +from tensorflow.python.framework import test_util +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class SoftsignBijectorTest(test.TestCase): + """Tests the correctness of the Y = g(X) = X / (1 + |X|) transformation.""" + + def _softsign(self, x): + return x / (1. + np.abs(x)) + + def _softsign_ildj_before_reduction(self, y): + """Inverse log det jacobian, before being reduced.""" + return -2. * np.log1p(-np.abs(y)) + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorBounds(self): + bijector = Softsign(validate_args=True) + with self.test_session(): + with self.assertRaisesOpError("greater than -1"): + bijector.inverse(-3.).eval() + with self.assertRaisesOpError("greater than -1"): + bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval() + + with self.assertRaisesOpError("less than 1"): + bijector.inverse(3.).eval() + with self.assertRaisesOpError("less than 1"): + bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() + + @test_util.run_in_graph_and_eager_modes() + def testBijectorForwardInverse(self): + bijector = Softsign(validate_args=True) + self.assertEqual("softsign", bijector.name) + x = 2. * self._rng.randn(2, 10) + y = self._softsign(x) + + self.assertAllClose(y, self.evaluate(bijector.forward(x))) + self.assertAllClose(x, self.evaluate(bijector.inverse(y))) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorLogDetJacobianEventDimsZero(self): + bijector = Softsign(validate_args=True) + y = self._rng.rand(2, 10) + # No reduction needed if event_dims = 0. + ildj = self._softsign_ildj_before_reduction(y) + + self.assertAllClose(ildj, self.evaluate( + bijector.inverse_log_det_jacobian(y, event_ndims=0))) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorForwardInverseEventDimsOne(self): + bijector = Softsign(validate_args=True) + self.assertEqual("softsign", bijector.name) + x = 2. * self._rng.randn(2, 10) + y = self._softsign(x) + self.assertAllClose(y, self.evaluate(bijector.forward(x))) + self.assertAllClose(x, self.evaluate(bijector.inverse(y))) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorLogDetJacobianEventDimsOne(self): + bijector = Softsign(validate_args=True) + y = self._rng.rand(2, 10) + ildj_before = self._softsign_ildj_before_reduction(y) + ildj = np.sum(ildj_before, axis=1) + self.assertAllClose( + ildj, self.evaluate( + bijector.inverse_log_det_jacobian(y, event_ndims=1))) + + def testScalarCongruency(self): + with self.test_session(): + bijector = Softsign(validate_args=True) + assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.) + + def testBijectiveAndFinite(self): + with self.test_session(): + bijector = Softsign(validate_args=True) + x = np.linspace(-20., 20., 100).astype(np.float32) + y = np.linspace(-0.99, 0.99, 100).astype(np.float32) + assert_bijective_and_finite( + bijector, x, y, event_ndims=0, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30c7a738c320b609ce90685512e6b8344dffc9dc --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class SquareBijectorTest(test.TestCase): + """Tests the correctness of the Y = X ** 2 transformation.""" + + def testBijectorScalar(self): + with self.test_session(): + bijector = bijectors.Square(validate_args=True) + self.assertEqual("square", bijector.name) + x = [[[1., 5], + [2, 1]], + [[np.sqrt(2.), 3], + [np.sqrt(8.), 1]]] + y = np.square(x) + ildj = -np.log(2.) - np.log(x) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + ildj, bijector.inverse_log_det_jacobian( + y, event_ndims=0).eval(), atol=0., rtol=1e-7) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), + atol=0., + rtol=1e-7) + + def testScalarCongruency(self): + with self.test_session(): + bijector = bijectors.Square(validate_args=True) + assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py index 7a31228d1ade55ce32b511dca073657d3bab53ae..f57adcda898a1fdb18aacbb0804411db1bb4e4c8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py @@ -36,7 +36,7 @@ class WeibullBijectorTest(test.TestCase): concentration = 0.3 bijector = Weibull( scale=scale, concentration=concentration, - event_ndims=1, validate_args=True) + validate_args=True) self.assertEqual("weibull", bijector.name) x = np.array([[[0.], [1.], [14.], [20.], [100.]]], dtype=np.float32) # Weibull distribution @@ -45,13 +45,11 @@ class WeibullBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval()) self.assertAllClose(x, bijector.inverse(y).eval()) self.assertAllClose( - # We should lose a dimension from calculating the determinant of the - # jacobian. - np.squeeze(weibull_dist.logpdf(x), axis=2), - bijector.forward_log_det_jacobian(x).eval()) + weibull_dist.logpdf(x), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval()) self.assertAllClose( - -bijector.inverse_log_det_jacobian(y).eval(), - bijector.forward_log_det_jacobian(x).eval(), + -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), rtol=1e-4, atol=0.) @@ -64,12 +62,12 @@ class WeibullBijectorTest(test.TestCase): def testBijectiveAndFinite(self): with self.test_session(): bijector = Weibull( - scale=20., concentration=2., event_ndims=0, validate_args=True) + scale=20., concentration=2., validate_args=True) x = np.linspace(1., 8., num=10).astype(np.float32) y = np.linspace( -np.expm1(-1 / 400.), -np.expm1(-16), num=10).astype(np.float32) - assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + assert_bijective_and_finite(bijector, x, y, event_ndims=0, rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py index 545471907f1eabc822b3d28ea9c57e183a09ff50..4e8989b6c2f93560b1fccbc99491d7809f494263 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py @@ -44,6 +44,7 @@ class _ChooseLocation(ConditionalBijector): graph_parents=[self._loc], is_constant_jacobian=True, validate_args=False, + forward_min_event_ndims=0, name=name) def _forward(self, x, z): @@ -52,7 +53,7 @@ class _ChooseLocation(ConditionalBijector): def _inverse(self, x, z): return x - self._gather_loc(z) - def _inverse_log_det_jacobian(self, x, z=None): + def _inverse_log_det_jacobian(self, x, event_ndims, z=None): return 0. def _gather_loc(self, z): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 507ceb35853ebe0a996d789b3bdf8a5f2284549c..68e0d9cb8277f3953039963fec0da499db7a16d1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -16,6 +16,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib import distributions from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -25,23 +27,23 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -ds = distributions +tfd = distributions class DistributionTest(test.TestCase): def testParamShapesAndFromParams(self): classes = [ - ds.Normal, - ds.Bernoulli, - ds.Beta, - ds.Chi2, - ds.Exponential, - ds.Gamma, - ds.InverseGamma, - ds.Laplace, - ds.StudentT, - ds.Uniform, + tfd.Normal, + tfd.Bernoulli, + tfd.Beta, + tfd.Chi2, + tfd.Exponential, + tfd.Gamma, + tfd.InverseGamma, + tfd.Laplace, + tfd.StudentT, + tfd.Uniform, ] sample_shapes = [(), (10,), (10, 20, 30)] @@ -63,15 +65,15 @@ class DistributionTest(test.TestCase): with self.test_session(): # Note: we cannot easily test all distributions since each requires # different initialization arguments. We therefore spot test a few. - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) self.assertEqual(normal.parameters, normal.copy().parameters) - wishart = ds.WishartFull(df=2, scale=[[1., 2], [2, 5]], - validate_args=True) + wishart = tfd.WishartFull(df=2, scale=[[1., 2], [2, 5]], + validate_args=True) self.assertEqual(wishart.parameters, wishart.copy().parameters) def testCopyOverride(self): with self.test_session(): - normal = ds.Normal(loc=1., scale=2., validate_args=True) + normal = tfd.Normal(loc=1., scale=2., validate_args=True) unused_normal_copy = normal.copy(validate_args=False) base_params = normal.parameters.copy() copy_params = normal.copy(validate_args=False).parameters.copy() @@ -84,19 +86,19 @@ class DistributionTest(test.TestCase): mu = 1. sigma = 2. - normal = ds.Normal(mu, sigma, validate_args=True) + normal = tfd.Normal(mu, sigma, validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertTrue(tensor_util.constant_value(normal.is_scalar_batch())) - normal = ds.Normal([mu], [sigma], validate_args=True) + normal = tfd.Normal([mu], [sigma], validate_args=True) self.assertTrue(tensor_util.constant_value(normal.is_scalar_event())) self.assertFalse(tensor_util.constant_value(normal.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([mu], [sigma], validate_args=True) + mvn = tfd.MultivariateNormalDiag([mu], [sigma], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertTrue(tensor_util.constant_value(mvn.is_scalar_batch())) - mvn = ds.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) + mvn = tfd.MultivariateNormalDiag([[mu]], [[sigma]], validate_args=True) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_event())) self.assertFalse(tensor_util.constant_value(mvn.is_scalar_batch())) @@ -126,7 +128,7 @@ class DistributionTest(test.TestCase): self.assertFalse(is_scalar.eval(feed_dict={x: [1]})) def _GetFakeDistribution(self): - class FakeDistribution(ds.Distribution): + class FakeDistribution(tfd.Distribution): """Fake Distribution for testing _set_sample_static_shape.""" def __init__(self, batch_shape=None, event_shape=None): @@ -188,6 +190,105 @@ class DistributionTest(test.TestCase): y = dist._set_sample_static_shape(x, sample_shape) self.assertTrue(y.get_shape().ndims is None) + def testStrWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + ("tf.distributions.Normal(" + "\"Normal\", " + "batch_shape=(), " + "event_shape=(), " + "dtype=float16)"), # Got the dtype right. + str(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + ("tf.distributions.Chi2(" + "\"silly\", " # What a silly name that is! + "batch_shape=(2,), " + "event_shape=(), " + "dtype=float32)"), + str(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("tf.distributions.Exponential(\"Exponential\", " + # No batch shape. + "event_shape=(), " + "dtype=float32)"), + str(exp)) + + def testStrWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN\", " + "batch_shape=(2,), " + "event_shape=(2,), " + "dtype=float64)"), + str(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + ("tf.distributions.MultivariateNormalDiag(" + "\"MVN2\", " + "batch_shape=(?,), " # Partially known. + "event_shape=(3,), " + "dtype=float32)"), + str(mvn_dynamic)) + + def testReprWorksCorrectlyScalar(self): + normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) + self.assertEqual( + (""), # Got the dtype right. + repr(normal)) + + chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") + self.assertEqual( + (""), + repr(chi2)) + + exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) + self.assertEqual( + ("" + " event_shape=()" + " dtype=float32>"), + repr(exp)) + + def testReprWorksCorrectlyMultivariate(self): + mvn_static = tfd.MultivariateNormalDiag( + loc=np.zeros([2, 2]), name="MVN") + self.assertEqual( + (""), + repr(mvn_static)) + + mvn_dynamic = tfd.MultivariateNormalDiag( + loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32), + name="MVN2") + self.assertEqual( + (""), + repr(mvn_dynamic)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py index 06318ca09dec851cf025fa35c83732b85824cbee..6a69f9e60b99a17c657f074597a075890265a93b 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib +from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -126,6 +127,100 @@ class ProductDistributionTest(test.TestCase): self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.) + def testKLRaises(self): + ind1 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=1) + ind2 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32(-1), + scale=np.float32(0.5)), + reinterpreted_batch_ndims=0) + + with self.assertRaisesRegexp( + ValueError, "Event shapes do not match"): + kullback_leibler.kl_divergence(ind1, ind2) + + ind1 = independent_lib.Independent( + distribution=normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=1) + ind2 = independent_lib.Independent( + distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([-1., 1]), + scale_diag=np.float32([0.1, 0.5])), + reinterpreted_batch_ndims=0) + + with self.assertRaisesRegexp( + NotImplementedError, "different event shapes"): + kullback_leibler.kl_divergence(ind1, ind2) + + def testKLScalarToMultivariate(self): + normal1 = normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])) + ind1 = independent_lib.Independent( + distribution=normal1, reinterpreted_batch_ndims=1) + + normal2 = normal_lib.Normal( + loc=np.float32([-3., 3]), + scale=np.float32([0.3, 0.3])) + ind2 = independent_lib.Independent( + distribution=normal2, reinterpreted_batch_ndims=1) + + normal_kl = kullback_leibler.kl_divergence(normal1, normal2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(math_ops.reduce_sum(normal_kl, axis=-1)), + self.evaluate(ind_kl)) + + def testKLIdentity(self): + normal1 = normal_lib.Normal( + loc=np.float32([-1., 1]), + scale=np.float32([0.1, 0.5])) + # This is functionally just a wrapper around normal1, + # and doesn't change any outputs. + ind1 = independent_lib.Independent( + distribution=normal1, reinterpreted_batch_ndims=0) + + normal2 = normal_lib.Normal( + loc=np.float32([-3., 3]), + scale=np.float32([0.3, 0.3])) + # This is functionally just a wrapper around normal2, + # and doesn't change any outputs. + ind2 = independent_lib.Independent( + distribution=normal2, reinterpreted_batch_ndims=0) + + normal_kl = kullback_leibler.kl_divergence(normal1, normal2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(normal_kl), self.evaluate(ind_kl)) + + def testKLMultivariateToMultivariate(self): + # (1, 1, 2) batch of MVNDiag + mvn1 = mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]), + scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]])) + ind1 = independent_lib.Independent( + distribution=mvn1, reinterpreted_batch_ndims=2) + + # (1, 1, 2) batch of MVNDiag + mvn2 = mvn_diag_lib.MultivariateNormalDiag( + loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]), + scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]])) + + ind2 = independent_lib.Independent( + distribution=mvn2, reinterpreted_batch_ndims=2) + + mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2) + ind_kl = kullback_leibler.kl_divergence(ind1, ind2) + self.assertAllClose( + self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])), + self.evaluate(ind_kl)) + def _testMnistLike(self, static_shape): sample_shape = [4, 5] batch_shape = [10] diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py index ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2..2980e2bfe93b2e2aa01d38fc9fa4650a015efc06 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -130,10 +130,8 @@ class KumaraswamyTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): + with self.assertRaisesOpError("sample must be non-negative"): dist.prob([-1., 0.1, 0.5]).eval() - with self.assertRaisesOpError("sample must be positive"): - dist.prob([0., 0.1, 0.5]).eval() with self.assertRaisesOpError("sample must be no larger than `1`"): dist.prob([.1, .2, 1.2]).eval() @@ -249,13 +247,13 @@ class KumaraswamyTest(test.TestCase): a = np.array([1., 2, 3]) b = np.array([2., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."): dist.mode().eval() a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."): dist.mode().eval() def testKumaraswamyModeEnableAllowNanStats(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 933756aa8e12cca4c42eb98d9193512bbf2ad585..9635134b08db47a47a17c869fe813e0376ae6f1e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -68,7 +68,7 @@ class MultivariateNormalDiagTest(test.TestCase): dist = ds.TransformedDistribution( base_dist, validate_args=True, - bijector=bijectors.Softplus(event_ndims=1)) + bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index d9c9008417cdb20b62390630cf887d3bd888a0d3..19a7472d91758a2dbd00c4d918853d7bae33685d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from scipy import special from scipy import stats from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import constant_op @@ -110,7 +111,7 @@ class PoissonTest(test.TestCase): batch_size = 6 lam = constant_op.constant([3.0] * batch_size) lam_v = 3.0 - x = [2.2, 3.1, 4., 5.5, 6., 7.] + x = [2., 3., 4., 5., 6., 7.] poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) @@ -121,12 +122,31 @@ class PoissonTest(test.TestCase): self.assertEqual(cdf.get_shape(), (6,)) self.assertAllClose(cdf.eval(), stats.poisson.cdf(x, lam_v)) + def testPoissonCDFNonIntegerValues(self): + with self.test_session(): + batch_size = 6 + lam = constant_op.constant([3.0] * batch_size) + lam_v = 3.0 + x = np.array([2.2, 3.1, 4., 5.5, 6., 7.], dtype=np.float32) + + poisson = self._make_poisson(rate=lam) + cdf = poisson.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + + # The Poisson CDF should be valid on these non-integer values, and + # equal to igammac(1 + x, rate). + self.assertAllClose(cdf.eval(), special.gammaincc(1. + x, lam_v)) + + with self.assertRaisesOpError("cannot contain fractional components"): + poisson_validate = self._make_poisson(rate=lam, validate_args=True) + poisson_validate.cdf(x).eval() + def testPoissonCdfMultidimensional(self): with self.test_session(): batch_size = 6 lam = constant_op.constant([[2.0, 4.0, 5.0]] * batch_size) lam_v = [2.0, 4.0, 5.0] - x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T + x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index 4186cf129dbf31724c84133734da3f226817c71a..ea04e8c29a2c94d4939bad277afa380401067ff2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test @@ -455,6 +456,16 @@ class PercentileTestWithNearestInterpolation(test.TestCase): with self.assertRaisesOpError("rank"): pct.eval(feed_dict={q_ph: [0.5]}) + def test_finds_max_of_long_array(self): + # d - 1 == d in float32 and d = 3e7. + # So this test only passes if we use double for the percentile indices. + # If float is used, it fails with InvalidArgumentError about an index out of + # bounds. + x = math_ops.linspace(0., 3e7, num=int(3e7)) + with self.test_session(): + minval = sample_stats.percentile(x, q=0, validate_args=True) + self.assertAllEqual(0, minval.eval()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py new file mode 100644 index 0000000000000000000000000000000000000000..968057331787059240110b90545f70c0ab128aa8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SeedStream class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import seed_stream +from tensorflow.python.platform import test + + +class SeedStreamTest(test.TestCase): + + def assertAllUnique(self, items): + self.assertEqual(len(items), len(set(items))) + + def testNonRepetition(self): + # The probability of repetitions in a short stream from a correct + # PRNG is negligible; this test catches bugs that prevent state + # updates. + strm = seed_stream.SeedStream(seed=4, salt="salt") + output = [strm() for _ in range(50)] + self.assertEqual(sorted(output), sorted(list(set(output)))) + + def testReproducibility(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(seed=4, salt="salt") + strm3 = seed_stream.SeedStream(seed=4, salt="salt") + outputs = [strm1() for _ in range(50)] + self.assertEqual(outputs, [strm2() for _ in range(50)]) + self.assertEqual(outputs, [strm3() for _ in range(50)]) + + def testSeededDistinctness(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(seed=5, salt="salt") + self.assertAllUnique( + [strm1() for _ in range(50)] + [strm2() for _ in range(50)]) + + def testSaltedDistinctness(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(seed=4, salt="another salt") + self.assertAllUnique( + [strm1() for _ in range(50)] + [strm2() for _ in range(50)]) + + def testNestingRobustness(self): + # SeedStreams started from generated seeds should not collide with + # the master or with each other, even if the salts are the same. + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1(), salt="salt") + strm3 = seed_stream.SeedStream(strm1(), salt="salt") + outputs = [strm1() for _ in range(50)] + self.assertAllUnique( + outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py index c8d795c3f6afbec5b41755951174439f7703efb9..243b5a034859288b0e2e120f09258cfee77fbdea 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py @@ -584,7 +584,6 @@ class DistributionShapeTest(test.TestCase): def testDistributionShapeGetDimsStatic(self): with self.test_session(): - shaper = _DistributionShape(batch_ndims=0, event_ndims=0) shaper = _DistributionShape(batch_ndims=0, event_ndims=0) x = 1 self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ce6cf702d522792f1ad26066a3d9be42003a0e3c --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -0,0 +1,243 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the statistical testing library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class StatisticalTestingTest(test.TestCase): + + def test_dkwm_design_mean_one_sample_soundness(self): + thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] + false_fail_rates, false_pass_rates = np.meshgrid(rates, rates) + false_fail_rates = false_fail_rates.flatten().astype(np.float32) + false_pass_rates = false_pass_rates.flatten().astype(np.float32) + + detectable_discrepancies = [] + for false_pass_rate, false_fail_rate in zip( + false_pass_rates, false_fail_rates): + sufficient_n = st.min_num_samples_for_dkwm_mean_test( + thresholds, low=0., high=1., false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate) + detectable_discrepancies.append( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + sufficient_n, low=0., high=1., false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate)) + + detectable_discrepancies_ = self.evaluate(detectable_discrepancies) + for discrepancies, false_pass_rate, false_fail_rate in zip( + detectable_discrepancies_, false_pass_rates, false_fail_rates): + below_threshold = discrepancies <= thresholds + self.assertAllEqual( + np.ones_like(below_threshold, np.bool), below_threshold, + msg='false_pass_rate({}), false_fail_rate({})'.format( + false_pass_rate, false_fail_rate)) + + def test_dkwm_design_mean_two_sample_soundness(self): + thresholds = [1e-5, 1e-2, 1.1e-1, 0.9, 1., 1.02, 2., 10., 1e2, 1e5, 1e10] + rates = [1e-6, 1e-3, 1e-2, 1.1e-1, 0.2, 0.5, 0.7, 1.] + false_fail_rates, false_pass_rates = np.meshgrid(rates, rates) + false_fail_rates = false_fail_rates.flatten().astype(np.float32) + false_pass_rates = false_pass_rates.flatten().astype(np.float32) + + detectable_discrepancies = [] + for false_pass_rate, false_fail_rate in zip( + false_pass_rates, false_fail_rates): + [ + sufficient_n1, + sufficient_n2 + ] = st.min_num_samples_for_dkwm_mean_two_sample_test( + thresholds, low1=0., high1=1., low2=0., high2=1., + false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate) + + detectable_discrepancies.append( + st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + n1=sufficient_n1, + low1=0., + high1=1., + n2=sufficient_n2, + low2=0., + high2=1., + false_fail_rate=false_fail_rate, + false_pass_rate=false_pass_rate)) + + detectable_discrepancies_ = self.evaluate(detectable_discrepancies) + for discrepancies, false_pass_rate, false_fail_rate in zip( + detectable_discrepancies_, false_pass_rates, false_fail_rates): + below_threshold = discrepancies <= thresholds + self.assertAllEqual( + np.ones_like(below_threshold, np.bool), below_threshold, + msg='false_pass_rate({}), false_fail_rate({})'.format( + false_pass_rate, false_fail_rate)) + + def test_true_mean_confidence_interval_by_dkwm_one_sample(self): + rng = np.random.RandomState(seed=0) + + num_samples = 5000 + # 5000 samples is chosen to be enough to find discrepancies of + # size 0.1 or more with assurance 1e-6, as confirmed here: + with self.test_session() as sess: + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = sess.run(d) + self.assertLess(d, 0.1) + + # Test that the confidence interval computed for the mean includes + # 0.5 and excludes 0.4 and 0.6. + with self.test_session() as sess: + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = sess.run([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) + + def test_dkwm_mean_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is 0.5. + samples = rng.uniform(size=num_samples).astype(np.float32) + with self.test_session() as sess: + sess.run(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaisesOpError("Mean confidence interval too high"): + sess.run(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaisesOpError("Mean confidence interval too low"): + sess.run(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + + # 4000 samples is chosen to be enough to find discrepancies of + # size 0.2 or more with assurance 1e-6, as confirmed here: + with self.test_session() as sess: + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = sess.run(d) + self.assertLess(d, 0.2) + + # Test that the test assertion agrees that the standard + # uniform distribution has the same mean as itself. + samples1 = rng.uniform(size=num_samples).astype(np.float32) + samples2 = rng.uniform(size=num_samples).astype(np.float32) + with self.test_session() as sess: + sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + samples1 = rng.uniform(size=num_samples).astype(np.float32) + + # As established above, 4000 samples is enough to find discrepancies + # of size 0.2 or more with assurance 1e-6. + + with self.test_session() as sess: + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("samples1 has a smaller mean"): + sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) + + def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): + rng = np.random.RandomState(seed=0) + num_samples = 4000 + samples1 = rng.uniform(size=num_samples).astype(np.float32) + + # As established above, 4000 samples is enough to find discrepancies + # of size 0.2 or more with assurance 1e-6. + + with self.test_session() as sess: + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). + beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("samples2 has a smaller mean"): + sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) + + def test_dkwm_argument_validity_checking(self): + rng = np.random.RandomState(seed=0) + samples = rng.uniform( + low=[0., 1.], high=[1., 2.], size=(2500, 1, 2)).astype(np.float32) + + # Test that the test library complains if the given samples fall + # outside the purported bounds. + with self.test_session() as sess: + with self.assertRaisesOpError("maximum value exceeds expectations"): + sess.run(st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) + with self.assertRaisesOpError("minimum value falls below expectations"): + sess.run(st.true_mean_confidence_interval_by_dkwm( + samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) + + # But doesn't complain if they don't. + op = st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) + _ = sess.run(op) + + def test_do_maximum_mean(self): + n = 117 + envelope = 0.02 # > 2 / n, but < 3 / n + rng = np.random.RandomState(seed=8) + samples = rng.uniform(size=n).astype(np.float32) + + # Compute the answer in TF using the code under test + with self.test_session() as sess: + envelope_t = ops.convert_to_tensor(envelope) + max_mean = st._do_maximum_mean(samples, envelope_t, 1) + max_mean = sess.run(max_mean) + + # Compute the correct answer for this case in numpy. In this + # example, `n` and `envelope` are such that `samples[2]` is the + # element that should be taken partially, regardless of the + # content of the `samples` array (see algorithm description in + # `../ops/statistical_testing.py`). + samples = sorted(samples) + weight = 1. / n - (envelope - 2. / n) + answer = samples[2] * weight + sum(samples[3:]) / n + envelope * 1. + self.assertAllClose(max_mean, answer, rtol=1e-9) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index cbaf74d3f66253ae5727e1ba579e2d49235b748e..5fe1331d2c34612e980c7b376367cd63b627533d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -36,6 +37,35 @@ ds = distributions la = linalg +class DummyMatrixTransform(bs.Bijector): + """Tractable matrix transformation. + + This is a non-sensical bijector that has forward/inverse_min_event_ndims=2. + The main use is to check that transformed distribution calculations are done + appropriately. + """ + + def __init__(self): + super(DummyMatrixTransform, self).__init__( + forward_min_event_ndims=2, + is_constant_jacobian=False, + validate_args=False, + name="dummy") + + def _forward(self, x): + return x + + def _inverse(self, y): + return y + + # Note: These jacobians don't make sense. + def _forward_log_det_jacobian(self, x): + return -linalg_ops.matrix_determinant(x) + + def _inverse_log_det_jacobian(self, x): + return linalg_ops.matrix_determinant(x) + + class TransformedDistributionTest(test.TestCase): def _cls(self): @@ -55,7 +85,7 @@ class TransformedDistributionTest(test.TestCase): # you may or may not need a reduce_sum. log_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.Exp(event_ndims=0)) + bijector=bs.Exp()) sp_dist = stats.lognorm(s=sigma, scale=np.exp(mu)) # sample @@ -87,7 +117,7 @@ class TransformedDistributionTest(test.TestCase): sigma = 2.0 abs_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.AbsoluteValue(event_ndims=0)) + bijector=bs.AbsoluteValue()) sp_normal = stats.norm(mu, sigma) # sample @@ -129,7 +159,7 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(grid, cdf_, rtol=1e-6, atol=0.) def testCachedSamples(self): - exp_forward_only = bs.Exp(event_ndims=0) + exp_forward_only = bs.Exp() exp_forward_only._inverse = self._make_unimplemented( "inverse") exp_forward_only._inverse_event_shape_tensor = self._make_unimplemented( @@ -153,7 +183,7 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(expected_log_pdf, log_pdf_val, rtol=1e-4, atol=0.) def testCachedSamplesInvert(self): - exp_inverse_only = bs.Exp(event_ndims=0) + exp_inverse_only = bs.Exp() exp_inverse_only._forward = self._make_unimplemented( "forward") exp_inverse_only._forward_event_shape_tensor = self._make_unimplemented( @@ -186,12 +216,14 @@ class TransformedDistributionTest(test.TestCase): standard_normal = ds.Normal(loc=0., scale=1.) multi_logit_normal = self._cls()( distribution=standard_normal, - bijector=softmax) - x = [[-np.log(3.), 0.], - [np.log(3), np.log(5)]] + bijector=softmax, + event_shape=[1]) + x = [[[-np.log(3.)], [0.]], + [[np.log(3)], [np.log(5)]]] y = softmax.forward(x).eval() - expected_log_pdf = (stats.norm(loc=0., scale=1.).logpdf(x) - - np.sum(np.log(y), axis=-1)) + expected_log_pdf = ( + np.squeeze(stats.norm(loc=0., scale=1.).logpdf(x)) - + np.sum(np.log(y), axis=-1)) self.assertAllClose(expected_log_pdf, multi_logit_normal.log_prob(y).eval()) self.assertAllClose( @@ -208,8 +240,11 @@ class TransformedDistributionTest(test.TestCase): int_identity = bs.Inline( forward_fn=array_ops.identity, inverse_fn=array_ops.identity, - inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), - forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + inverse_log_det_jacobian_fn=( + lambda y: math_ops.cast(0, dtypes.int32)), + forward_log_det_jacobian_fn=( + lambda x: math_ops.cast(0, dtypes.int32)), + forward_min_event_ndims=0, is_constant_jacobian=True) normal = self._cls()( distribution=ds.Normal(loc=0., scale=1.), @@ -245,9 +280,8 @@ class TransformedDistributionTest(test.TestCase): with self.test_session() as sess: exp2 = self._cls()( ds.Exponential(rate=0.25), - bijector=ds.bijectors.Affine( - scale_identity_multiplier=2., - event_ndims=0)) + bijector=ds.bijectors.AffineScalar(scale=2.) + ) log_prob = exp2.log_prob(1.) log_prob_ = sess.run(log_prob) base_log_prob = -0.5 * 0.25 + np.log(0.25) @@ -434,6 +468,82 @@ class ScalarToMultiTest(test.TestCase): event_shape=[3], validate_args=True) + def testMatrixEvent(self): + with self.test_session() as sess: + batch_shape = [2] + event_shape = [2, 3, 3] + batch_shape_pl = array_ops.placeholder( + dtypes.int32, name="dynamic_batch_shape") + event_shape_pl = array_ops.placeholder( + dtypes.int32, name="dynamic_event_shape") + feed_dict = {batch_shape_pl: np.array(batch_shape, dtype=np.int32), + event_shape_pl: np.array(event_shape, dtype=np.int32)} + + scale = 2. + loc = 0. + fake_mvn_dynamic = self._cls()( + distribution=ds.Normal( + loc=loc, + scale=scale), + bijector=DummyMatrixTransform(), + batch_shape=batch_shape_pl, + event_shape=event_shape_pl, + validate_args=True) + + fake_mvn_static = self._cls()( + distribution=ds.Normal( + loc=loc, + scale=scale), + bijector=DummyMatrixTransform(), + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=True) + + def actual_mvn_log_prob(x): + # This distribution is the normal PDF, reduced over the + # last 3 dimensions + a jacobian term which corresponds + # to the determinant of x. + return (np.sum( + stats.norm(loc, scale).logpdf(x), axis=(-1, -2, -3)) + + np.sum(np.linalg.det(x), axis=-1)) + + self.assertAllEqual([2, 3, 3], fake_mvn_static.event_shape) + self.assertAllEqual([2], fake_mvn_static.batch_shape) + + self.assertAllEqual(tensor_shape.TensorShape(None), + fake_mvn_dynamic.event_shape) + self.assertAllEqual(tensor_shape.TensorShape(None), + fake_mvn_dynamic.batch_shape) + + num_samples = 5e3 + for fake_mvn, feed_dict in ((fake_mvn_static, {}), + (fake_mvn_dynamic, feed_dict)): + # Ensure sample works by checking first, second moments. + y = fake_mvn.sample(int(num_samples), seed=0) + x = y[0:5, ...] + [ + x_, + fake_event_shape_, + fake_batch_shape_, + fake_log_prob_, + fake_prob_, + ] = sess.run([ + x, + fake_mvn.event_shape_tensor(), + fake_mvn.batch_shape_tensor(), + fake_mvn.log_prob(x), + fake_mvn.prob(x), + ], feed_dict=feed_dict) + + # Ensure all other functions work as intended. + self.assertAllEqual([5, 2, 2, 3, 3], x_.shape) + self.assertAllEqual([2, 3, 3], fake_event_shape_) + self.assertAllEqual([2], fake_batch_shape_) + self.assertAllClose(actual_mvn_log_prob(x_), fake_log_prob_, + atol=0., rtol=1e-6) + self.assertAllClose(np.exp(actual_mvn_log_prob(x_)), fake_prob_, + atol=0., rtol=1e-5) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py index c355adeedbfff1072281a81de726ddb0ece07882..1226c66113ec4b43f57371abf4983aef1a529ec1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py @@ -61,7 +61,7 @@ class VectorLaplaceDiagTest(test.TestCase): dist = ds.TransformedDistribution( base_dist, validate_args=True, - bijector=bijectors.Softplus(event_ndims=1)) + bijector=bijectors.Softplus()) samps = dist.sample(5) # Shape [5, 1, 3]. self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape()) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index 9044aa2850ae35f29cd48b0c5f54aa948bea0408..dcecce981f16a2d9e772d4e40062ff250725c3ac 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -390,6 +390,26 @@ class WishartCholeskyTest(test.TestCase): chol_scale, dtype=np.int32), validate_args=False) + def testSampleBroadcasts(self): + dims = 2 + batch_shape = [2, 3] + sample_shape = [2, 1] + scale = np.float32([ + [[1., 0.5], + [0.5, 1.]], + [[0.5, 0.25], + [0.25, 0.75]], + ]) + scale = np.reshape(np.concatenate([scale, scale, scale], axis=0), + batch_shape + [dims, dims]) + wishart = distributions.WishartFull(df=5, scale=scale) + x = wishart.sample(sample_shape, seed=42) + with self.test_session() as sess: + x_ = sess.run(x) + expected_shape = sample_shape + batch_shape + [dims, dims] + self.assertAllEqual(expected_shape, x.shape) + self.assertAllEqual(expected_shape, x_.shape) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 852298bf334666db003353d5fc8e172ffb738668..69f3d57ff000d6c9acc8aa9e3d0ad8d9cbb6bb3c 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -36,7 +36,8 @@ class Autoregressive(distribution_lib.Distribution): "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." [(Papamakarios et + al., 2016)][1] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -45,17 +46,18 @@ class Autoregressive(distribution_lib.Distribution): Practically speaking the autoregressive property means that there exists a permutation of the event coordinates such that each coordinate is a - diffeomorphic function of only preceding coordinates. [2] + diffeomorphic function of only preceding coordinates + [(van den Oord et al., 2016)][2]. #### Mathematical Details - The probability function is, + The probability function is ```none prob(x; fn, n) = fn(x).prob(x) ``` - And a sample is generated by, + And a sample is generated by ```none x = fn(...fn(fn(x0).sample()).sample()).sample() @@ -93,13 +95,15 @@ class Autoregressive(distribution_lib.Distribution): ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "Conditional Image Generation with PixelCNN Decoders." - Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex - Graves, Koray Kavukcuoglu. Arxiv, 2016. + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + + [2]: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, + Alex Graves, and Koray Kavukcuoglu. Conditional Image Generation with + PixelCNN Decoders. In _Neural Information Processing Systems_, 2016. https://arxiv.org/abs/1606.05328 """ diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5590cd552a915a3ecfc1912ee530baf79665a6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -0,0 +1,416 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The BatchReshape distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import distribution as distribution_lib + + +__all__ = [ + "BatchReshape", +] + + +class BatchReshape(distribution_lib.Distribution): + """The Batch-Reshaping distribution. + + This "meta-distribution" reshapes the batch dimensions of another + distribution. + + Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support + `-1` for flattening. + + #### Examples + + ```python + tfd = tf.contrib.distributions + + dtype = np.float32 + dims = 2 + new_batch_shape = [1, 2, 3] + old_batch_shape = [6] + + scale = np.ones(old_batch_shape + [dims], dtype) + mvn = tfd.MultivariateNormalDiag(scale_diag=scale) + reshape_mvn = tfd.BatchReshape( + distribution=mvn, + batch_shape=new_batch_shape, + validate_args=True) + + reshape_mvn.batch_shape + # ==> [1, 2, 3] + + x = reshape_mvn.sample(sample_shape=[4, 5]) + x.shape + # ==> [4, 5, 1, 2, 3, 2] == sample_shape + new_batch_shape + [dims] + + reshape_mvn.log_prob(x).shape + # ==> [4, 5, 1, 2, 3] == sample_shape + new_batch_shape + ``` + + """ + + def __init__(self, + distribution, + batch_shape, + validate_args=False, + allow_nan_stats=True, + name=None): + """Construct BatchReshape distribution. + + Args: + distribution: The base distribution instance to reshape. Typically an + instance of `Distribution`. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing the + new shape of the batch dimensions. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: The name to give Ops created by the initializer. + Default value: `"BatchReshape" + distribution.name`. + + Raises: + ValueError: if `batch_shape` is not a vector. + ValueError: if `batch_shape` has non-positive elements. + ValueError: if `batch_shape` size is not the same as a + `distribution.batch_shape` size. + """ + parameters = locals() + name = name or "BatchReshape" + distribution.name + self._distribution = distribution + with ops.name_scope(name, values=[batch_shape]) as name: + self._batch_shape_ = ops.convert_to_tensor( + batch_shape, + dtype=dtypes.int32, + name="batch_shape") + self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) + if self._batch_shape_static is not None: + self._batch_shape_static = np.int32(self._batch_shape_static) + self._runtime_assertions = validate_init_args( + self._distribution, + self._batch_shape_, + validate_args, + self._batch_shape_static) + super(BatchReshape, self).__init__( + dtype=self._distribution.dtype, + reparameterization_type=self._distribution.reparameterization_type, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=( + [self._batch_shape_] + + self._distribution._graph_parents), # pylint: disable=protected-access + name=name) + + @property + def distribution(self): + return self._distribution + + def _batch_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return array_ops.identity(self._batch_shape_) + + def _batch_shape(self): + return tensor_shape.TensorShape(self._batch_shape_static) + + def _event_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return array_ops.identity(self.distribution.event_shape_tensor()) + + def _event_shape(self): + return self.distribution.event_shape + + def _sample_n(self, n, seed=None): + with ops.control_dependencies(self._runtime_assertions): + x = self.distribution.sample(sample_shape=n, seed=seed) + new_shape = array_ops.concat([ + [n], + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + return array_ops.reshape(x, new_shape) + + def _log_prob(self, x): + return self._call_reshape_input_output( + self.distribution.log_prob, x) + + def _prob(self, x): + return self._call_reshape_input_output( + self.distribution.prob, x) + + def _log_cdf(self, x): + return self._call_reshape_input_output( + self.distribution.log_cdf, x) + + def _cdf(self, x): + return self._call_reshape_input_output( + self.distribution.cdf, x) + + def _log_survival_function(self, x): + return self._call_reshape_input_output( + self.distribution.log_survival_function, x) + + def _survival_function(self, x): + return self._call_reshape_input_output( + self.distribution.survival_function, x) + + def _entropy(self): + return self._call_and_reshape_output( + self.distribution.entropy, + [], + [tensor_shape.scalar()]) + + def _mean(self): + return self._call_and_reshape_output(self.distribution.mean) + + def _mode(self): + return self._call_and_reshape_output(self.distribution.mode) + + def _stddev(self): + return self._call_and_reshape_output(self.distribution.stddev) + + def _variance(self): + return self._call_and_reshape_output(self.distribution.variance) + + def _covariance(self): + return self._call_and_reshape_output( + self.distribution.covariance, + [self.event_shape_tensor()]*2, + [self.event_shape]*2) + + def _sample_shape(self, x): + """Computes graph and static `sample_shape`.""" + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = (array_ops.size(self.batch_shape_tensor()) + if self.batch_shape.ndims is None + else self.batch_shape.ndims) + sample_ndims = x_ndims - batch_ndims - event_ndims + if isinstance(sample_ndims, int): + static_sample_shape = x.shape[:sample_ndims] + else: + static_sample_shape = tensor_shape.TensorShape(None) + if static_sample_shape.is_fully_defined(): + sample_shape = np.int32(static_sample_shape.as_list()) + else: + sample_shape = array_ops.shape(x)[:sample_ndims] + return sample_shape, static_sample_shape + + def _call_reshape_input_output(self, fn, x): + """Calls `fn`, appropriately reshaping its input `x` and output.""" + with ops.control_dependencies( + self._runtime_assertions + self._validate_sample_arg(x)): + sample_shape, static_sample_shape = self._sample_shape(x) + old_shape = array_ops.concat([ + sample_shape, + self.distribution.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + result = fn(array_ops.reshape(x, old_shape)) + new_shape = array_ops.concat([ + sample_shape, + self.batch_shape_tensor(), + ], axis=0) + result = array_ops.reshape(result, new_shape) + if (static_sample_shape.ndims is not None and + self.batch_shape.ndims is not None): + new_shape = static_sample_shape.concatenate(self.batch_shape) + result.set_shape(result.shape.merge_with(new_shape)) + return result + + def _call_and_reshape_output( + self, + fn, + event_shape_list=None, + static_event_shape_list=None): + """Calls `fn` and appropriately reshapes its output.""" + with ops.control_dependencies(self._runtime_assertions): + if event_shape_list is None: + event_shape_list = [self._event_shape_tensor()] + if static_event_shape_list is None: + static_event_shape_list = [self.event_shape] + new_shape = array_ops.concat( + [self.batch_shape_tensor()] + event_shape_list, + axis=0) + result = array_ops.reshape(fn(), new_shape) + if (self.batch_shape.ndims is not None and + self.event_shape.ndims is not None): + event_shape = tensor_shape.TensorShape([]) + for rss in static_event_shape_list: + event_shape = event_shape.concatenate(rss) + static_shape = result.shape.merge_with( + self.batch_shape.concatenate(event_shape)) + result.set_shape(static_shape) + return result + + def _validate_sample_arg(self, x): + """Helper which validates sample arg, e.g., input to `log_prob`.""" + with ops.name_scope(name="validate_sample_arg", values=[x]): + x_ndims = (array_ops.rank(x) if x.shape.ndims is None else x.shape.ndims) + event_ndims = (array_ops.size(self.event_shape_tensor()) + if self.event_shape.ndims is None + else self.event_shape.ndims) + batch_ndims = (array_ops.size(self.batch_shape_tensor()) + if self.batch_shape.ndims is None + else self.batch_shape.ndims) + expected_batch_event_ndims = batch_ndims + event_ndims + + if (isinstance(x_ndims, int) and + isinstance(expected_batch_event_ndims, int)): + if x_ndims < expected_batch_event_ndims: + raise NotImplementedError( + "Broadcasting is not supported; too few batch and event dims " + "(expected at least {}, saw {}).".format( + expected_batch_event_ndims, x_ndims)) + ndims_assertion = [] + elif self.validate_args: + ndims_assertion = [ + check_ops.assert_greater_equal( + x_ndims, + expected_batch_event_ndims, + message=("Broadcasting is not supported; too few " + "batch and event dims."), + name="assert_batch_and_event_ndims_large_enough"), + ] + + if (self.batch_shape.is_fully_defined() and + self.event_shape.is_fully_defined()): + expected_batch_event_shape = np.int32(self.batch_shape.concatenate( + self.event_shape).as_list()) + else: + expected_batch_event_shape = array_ops.concat([ + self.batch_shape_tensor(), + self.event_shape_tensor(), + ], axis=0) + + sample_ndims = x_ndims - expected_batch_event_ndims + if isinstance(sample_ndims, int): + sample_ndims = max(sample_ndims, 0) + if (isinstance(sample_ndims, int) and + x.shape[sample_ndims:].is_fully_defined()): + actual_batch_event_shape = np.int32(x.shape[sample_ndims:].as_list()) + else: + sample_ndims = math_ops.maximum(sample_ndims, 0) + actual_batch_event_shape = array_ops.shape(x)[sample_ndims:] + + if (isinstance(expected_batch_event_shape, np.ndarray) and + isinstance(actual_batch_event_shape, np.ndarray)): + if any(expected_batch_event_shape != actual_batch_event_shape): + raise NotImplementedError("Broadcasting is not supported; " + "unexpected batch and event shape " + "(expected {}, saw {}).".format( + expected_batch_event_shape, + actual_batch_event_shape)) + # We need to set the final runtime-assertions to `ndims_assertion` since + # its possible this assertion was created. We could add a condition to + # only do so if `self.validate_args == True`, however this is redundant + # as `ndims_assertion` already encodes this information. + runtime_assertions = ndims_assertion + elif self.validate_args: + # We need to make the `ndims_assertion` a control dep because otherwise + # TF itself might raise an exception owing to this assertion being + # ill-defined, ie, one cannot even compare different rank Tensors. + with ops.control_dependencies(ndims_assertion): + shape_assertion = check_ops.assert_equal( + expected_batch_event_shape, + actual_batch_event_shape, + message=("Broadcasting is not supported; " + "unexpected batch and event shape."), + name="assert_batch_and_event_shape_same") + runtime_assertions = [shape_assertion] + else: + runtime_assertions = [] + + return runtime_assertions + + +def validate_init_args( + distribution, + batch_shape, + validate_args, + batch_shape_static): + """Helper to __init__ which makes or raises assertions.""" + with ops.name_scope(name="validate_init_args", + values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access + runtime_assertions = [] + + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format( + batch_shape.shape.ndims)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_rank( + batch_shape, + 1, + message="`batch_shape` must be a vector.", + name="assert_batch_shape_is_vector"), + ] + + batch_size_static = np.prod(batch_shape_static) + dist_batch_size_static = ( + None if not distribution.batch_shape.is_fully_defined() + else np.prod(distribution.batch_shape).value) + + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, + dist_batch_size_static)) + elif validate_args: + runtime_assertions += [ + check_ops.assert_equal( + math_ops.reduce_prod(batch_shape), + math_ops.reduce_prod(distribution.batch_shape_tensor()), + message=("`batch_shape` size must match " + "`distributions.batch_shape` size."), + name="assert_batch_size"), + ] + + if batch_shape_static is not None: + if np.any(batch_shape_static < 1): + raise ValueError("`batch_shape` elements must be positive " + "(i.e., larger than zero).") + elif validate_args: + runtime_assertions += [ + check_ops.assert_positive( + batch_shape, + message=("`batch_shape` elements must be positive " + "(i.e., larger than zero)."), + name="assert_batch_shape_positive") + ] + + return runtime_assertions diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 93923c3f083c7f5136b55e9021cbd6323684b976..babce80396cfc41b53e99f91038d4f077c7efe82 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -17,7 +17,9 @@ @@AbsoluteValue @@Affine @@AffineLinearOperator +@@AffineScalar @@Bijector +@@BatchNormalization @@Chain @@CholeskyOuterProduct @@ConditionalBijector @@ -26,16 +28,18 @@ @@Identity @@Inline @@Invert +@@Kumaraswamy @@MaskedAutoregressiveFlow @@Permute @@PowerTransform @@RealNVP @@Reshape @@Sigmoid -@@SigmoidCentered @@SinhArcsinh @@SoftmaxCentered @@Softplus +@@Softsign +@@Square @@Weibull @@masked_autoregressive_default_template @@ -52,6 +56,8 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import * from tensorflow.contrib.distributions.python.ops.bijectors.affine import * from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import * +from tensorflow.contrib.distributions.python.ops.bijectors.affine_scalar import * +from tensorflow.contrib.distributions.python.ops.bijectors.batch_normalization import * from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * @@ -59,16 +65,18 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * +from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * +from tensorflow.contrib.distributions.python.ops.bijectors.square import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 0fe9f6aa78fbe845b99d0668f075b0162ec2a9f7..c9e31d7712f09f6c4b4cc6ae51a34c42a19c291d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,9 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import constant_op from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -72,38 +70,22 @@ class AbsoluteValue(bijector.Bijector): """ - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. validate_args: Python `bool` indicating whether arguments should be checked for correctness, in particular whether inputs to `inverse` and `inverse_log_det_jacobian` are non-negative. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. """ self._graph_parents = [] self._name = name - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - with self._name_scope("init"): super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, + is_constant_jacobian=True, validate_args=validate_args, name=name) @@ -121,8 +103,7 @@ class AbsoluteValue(bijector.Bijector): # If event_ndims = 2, # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. - batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + zeros = constant_op.constant(0., dtype=y.dtype) if self.validate_args: zeros = control_flow_ops.with_dependencies( [check_ops.assert_non_negative(y, message="Argument y was negative")], diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 05bb9c2f9bdf35e222c94db3491157893da64ebd..b4c2939eb914d50475ba6b1c1e979a804090f641 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -62,7 +62,7 @@ class Affine(bijector.Bijector): matrices, i.e., the matmul is [matrix-free]( https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - Examples: + #### Examples ```python # Y = X @@ -104,7 +104,6 @@ class Affine(bijector.Bijector): scale_tril=None, scale_perturb_factor=None, scale_perturb_diag=None, - event_ndims=1, validate_args=False, name="affine"): """Instantiates the `Affine` bijector. @@ -157,8 +156,6 @@ class Affine(bijector.Bijector): matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which represents an `r x r` diagonal matrix. When `None` low rank updates will take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -187,22 +184,6 @@ class Affine(bijector.Bijector): with self._name_scope("init", values=[ shift, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -251,12 +232,11 @@ class Affine(bijector.Bijector): self._scale = scale self._shaper = _DistributionShape( batch_ndims=batch_ndims, - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args) super(Affine, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, graph_parents=( - [event_ndims] + [self._scale] if tensor_util.is_tensor(self._scale) else self._scale.graph_parents + [self._shift] if self._shift is not None else []), @@ -381,18 +361,17 @@ class Affine(bijector.Bijector): x, sample_shape, expand_batch_dim=False) return x - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self._is_only_identity_multiplier: # We don't pad in this case and instead let the fldj be applied # via broadcast. - event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] + event_size = array_ops.shape(x)[-1] event_size = math_ops.cast(event_size, dtype=self._scale.dtype) return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() def _maybe_check_scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 89043b1410370074f11f2cfa59b6b6663fa62521..59f9742d576a7804f401d3a47ba31ae61d6c6e54 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -22,9 +22,6 @@ from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator @@ -94,7 +91,6 @@ class AffineLinearOperator(bijector.Bijector): def __init__(self, shift=None, scale=None, - event_ndims=1, validate_args=False, name="affine_linear_operator"): """Instantiates the `AffineLinearOperator` bijector. @@ -103,14 +99,11 @@ class AffineLinearOperator(bijector.Bijector): shift: Floating-point `Tensor`. scale: Subclass of `LinearOperator`. Represents the (batch) positive definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. Raises: - ValueError: if `event_ndims` is not 0 or 1. TypeError: if `scale` is not a `LinearOperator`. TypeError: if `shift.dtype` does not match `scale.dtype`. ValueError: if not `scale.is_non_singular`. @@ -120,20 +113,6 @@ class AffineLinearOperator(bijector.Bijector): self._validate_args = validate_args graph_parents = [] with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -166,10 +145,10 @@ class AffineLinearOperator(bijector.Bijector): self._scale = scale self._shaper = _DistributionShape( batch_ndims=batch_ndims, - event_ndims=event_ndims, + event_ndims=1, validate_args=validate_args) super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, graph_parents=graph_parents, is_constant_jacobian=True, dtype=dtype, @@ -213,12 +192,13 @@ class AffineLinearOperator(bijector.Bijector): x, sample_shape, expand_batch_dim=False) return x - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) + return constant_op.constant(0., dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if self.validate_args else []): return self.scale.log_abs_determinant() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py new file mode 100644 index 0000000000000000000000000000000000000000..cd792e2c8cf48602daf9fb5eb56b8c34bac050c7 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -0,0 +1,141 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Affine bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "AffineScalar", +] + + +class AffineScalar(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale * X + shift`. + + Examples: + + ```python + # Y = X + b = AffineScalar() + + # Y = X + shift + b = AffineScalar(shift=[1., 2, 3]) + + # Y = 2 * X + shift + b = AffineScalar( + shift=[1., 2, 3], + scale=2.) + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + validate_args=False, + name="affine_scalar"): + """Instantiates the `AffineScalar` bijector. + + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale * X + shift + ``` + + if `scale` is not specified, then the bijector has the semantics of + `scale = 1.`. Similarly, if `shift` is not specified, then the bijector + has the semantics of `shift = 0.`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale: Floating-point `Tensor`. If this is set to `None`, no scale is + applied. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + with self._name_scope("init", values=[scale, shift]): + self._shift = shift + self._scale = scale + + if self._shift is not None: + self._shift = ops.convert_to_tensor(shift, name="shift") + + if self._scale is not None: + self._scale = ops.convert_to_tensor(self._scale, name="scale") + if validate_args: + self._scale = control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + self._scale, + array_ops.zeros([], dtype=self._scale.dtype))], + self._scale) + + super(AffineScalar, self).__init__( + forward_min_event_ndims=0, + is_constant_jacobian=True, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = array_ops.identity(x) + if self.scale is not None: + y *= self.scale + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = array_ops.identity(y) + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x /= self.scale + return x + + def _forward_log_det_jacobian(self, x): + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. + if self.scale is None: + return constant_op.constant(0., dtype=x.dtype.base_dtype) + + return math_ops.log(math_ops.abs(self.scale)) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..224cec8a63dba53a528490117efac890312fe8d5 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -0,0 +1,263 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Batch Norm bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "BatchNormalization", +] + + +def _undo_batch_normalization(x, + mean, + variance, + offset, + scale, + variance_epsilon, + name=None): + r"""Inverse of tf.nn.batch_normalization. + + Args: + x: Input `Tensor` of arbitrary dimensionality. + mean: A mean `Tensor`. + variance: A variance `Tensor`. + offset: An offset `Tensor`, often denoted `beta` in equations, or + None. If present, will be added to the normalized tensor. + scale: A scale `Tensor`, often denoted `gamma` in equations, or + `None`. If present, the scale is applied to the normalized tensor. + variance_epsilon: A small `float` added to the minibatch `variance` to + prevent dividing by zero. + name: A name for this operation (optional). + + Returns: + batch_unnormalized: The de-normalized, de-scaled, de-offset `Tensor`. + """ + with ops.name_scope( + name, "undo_batchnorm", [x, mean, variance, scale, offset]): + # inv = math_ops.rsqrt(variance + variance_epsilon) + # if scale is not None: + # inv *= scale + # return x * inv + ( + # offset - mean * inv if offset is not None else -mean * inv) + rescale = math_ops.sqrt(variance + variance_epsilon) + if scale is not None: + rescale /= scale + batch_unnormalized = x * rescale + ( + mean - offset * rescale if offset is not None else mean) + return batch_unnormalized + + +class BatchNormalization(bijector.Bijector): + """Compute `Y = g(X) s.t. X = g^-1(Y) = (Y - mean(Y)) / std(Y)`. + + Applies Batch Normalization [(Ioffe and Szegedy, 2015)][1] to samples from a + data distribution. This can be used to stabilize training of normalizing + flows ([Papamakarios et al., 2016][3]; [Dinh et al., 2017][2]) + + When training Deep Neural Networks (DNNs), it is common practice to + normalize or whiten features by shifting them to have zero mean and + scaling them to have unit variance. + + The `inverse()` method of the `BatchNormalization` bijector, which is used in + the log-likelihood computation of data samples, implements the normalization + procedure (shift-and-scale) using the mean and standard deviation of the + current minibatch. + + Conversely, the `forward()` method of the bijector de-normalizes samples (e.g. + `X*std(Y) + mean(Y)` with the running-average mean and standard deviation + computed at training-time. De-normalization is useful for sampling. + + ```python + + dist = tfd.TransformedDistribution( + distribution=tfd.Normal()), + bijector=tfb.BatchNorm()) + + y = tfd.MultivariateNormalDiag(loc=1., scale=2.).sample(100) # ~ N(1, 2) + x = dist.bijector.inverse(y) # ~ N(0, 1) + y = dist.sample() # ~ N(1, 2) + ``` + + During training time, `BatchNorm.inverse` and `BatchNorm.forward` are not + guaranteed to be inverses of each other because `inverse(y)` uses statistics + of the current minibatch, while `forward(x)` uses running-average statistics + accumulated from training. In other words, + `BatchNorm.inverse(BatchNorm.forward(...))` and + `BatchNorm.forward(BatchNorm.inverse(...))` will be identical when + `training=False` but may be different when `training=True`. + + #### References + + [1]: Sergey Ioffe and Christian Szegedy. Batch Normalization: Accelerating + Deep Network Training by Reducing Internal Covariate Shift. In + _International Conference on Machine Learning_, 2015. + https://arxiv.org/abs/1502.03167 + + [2]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 + """ + + def __init__(self, + batchnorm_layer=None, + training=True, + validate_args=False, + name="batch_normalization"): + """Instantiates the `BatchNorm` bijector. + + Args: + batchnorm_layer: `tf.layers.BatchNormalization` layer object. If `None`, + defaults to + `tf.layers.BatchNormalization(gamma_constraint=nn_ops.relu(x) + 1e-6)`. + This ensures positivity of the scale variable. + + training: If True, updates running-average statistics during call to + `inverse()`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + Raises: + ValueError: If bn_layer is not an instance of + `tf.layers.BatchNormalization`, or if it is specified with `renorm=True` + or a virtual batch size. + """ + # Scale must be positive. + g_constraint = lambda x: nn.relu(x) + 1e-6 + self.batchnorm = batchnorm_layer or normalization.BatchNormalization( + gamma_constraint=g_constraint) + self._validate_bn_layer(self.batchnorm) + self._training = training + if isinstance(self.batchnorm.axis, int): + forward_min_event_ndims = 1 + else: + forward_min_event_ndims = len(self.batchnorm.axis) + super(BatchNormalization, self).__init__( + forward_min_event_ndims=forward_min_event_ndims, + validate_args=validate_args, name=name) + + def _validate_bn_layer(self, layer): + """Check for valid BatchNormalization layer. + + Args: + layer: Instance of `tf.layers.BatchNormalization`. + Raises: + ValueError: If batchnorm_layer argument is not an instance of + `tf.layers.BatchNormalization`, or if `batchnorm_layer.renorm=True` or + if `batchnorm_layer.virtual_batch_size` is specified. + """ + if not isinstance(layer, normalization.BatchNormalization): + raise ValueError( + "batchnorm_layer must be an instance of BatchNormalization layer.") + if layer.renorm: + raise ValueError("BatchNorm Bijector does not support renormalization.") + if layer.virtual_batch_size: + raise ValueError( + "BatchNorm Bijector does not support virtual batch sizes.") + + def _get_broadcast_fn(self, x): + # Compute shape to broadcast scale/shift parameters to. + if not x.shape.is_fully_defined(): + raise ValueError("Input must have shape known at graph construction.") + input_shape = np.int32(x.shape.as_list()) + + ndims = len(input_shape) + reduction_axes = [i for i in range(ndims) if i not in self.batchnorm.axis] + # Broadcasting only necessary for single-axis batch norm where the axis is + # not the last dimension + broadcast_shape = [1] * ndims + broadcast_shape[self.batchnorm.axis[0]] = ( + input_shape[self.batchnorm.axis[0]]) + def _broadcast(v): + if (v is not None and + len(v.get_shape()) != ndims and + reduction_axes != list(range(ndims - 1))): + return array_ops.reshape(v, broadcast_shape) + return v + return _broadcast + + def _normalize(self, y): + return self.batchnorm.apply(y, training=self._training) + + def _de_normalize(self, x): + # Uses the saved statistics. + if not self.batchnorm.built: + input_shape = x.get_shape() + self.batchnorm.build(input_shape) + broadcast_fn = self._get_broadcast_fn(x) + mean = broadcast_fn(self.batchnorm.moving_mean) + variance = broadcast_fn(self.batchnorm.moving_variance) + beta = broadcast_fn(self.batchnorm.beta) if self.batchnorm.center else None + gamma = broadcast_fn(self.batchnorm.gamma) if self.batchnorm.scale else None + return _undo_batch_normalization( + x, mean, variance, beta, gamma, self.batchnorm.epsilon) + + def _forward(self, x): + return self._de_normalize(x) + + def _inverse(self, y): + return self._normalize(y) + + def _forward_log_det_jacobian(self, x): + # Uses saved statistics to compute volume distortion. + return -self._inverse_log_det_jacobian(x, use_saved_statistics=True) + + def _inverse_log_det_jacobian(self, y, use_saved_statistics=False): + if not y.shape.is_fully_defined(): + raise ValueError("Input must have shape known at graph construction.") + input_shape = np.int32(y.shape.as_list()) + + if not self.batchnorm.built: + # Create variables. + self.batchnorm.build(input_shape) + + event_dims = self.batchnorm.axis + reduction_axes = [i for i in range(len(input_shape)) if i not in event_dims] + + if use_saved_statistics or not self._training: + log_variance = math_ops.log( + self.batchnorm.moving_variance + self.batchnorm.epsilon) + else: + # At training-time, ildj is computed from the mean and log-variance across + # the current minibatch. + _, v = nn.moments(y, axes=reduction_axes, keep_dims=True) + log_variance = math_ops.log(v + self.batchnorm.epsilon) + + # `gamma` and `log Var(y)` reductions over event_dims. + # Log(total change in area from gamma term). + log_total_gamma = math_ops.reduce_sum(math_ops.log(self.batchnorm.gamma)) + + # Log(total change in area from log-variance term). + log_total_variance = math_ops.reduce_sum(log_variance) + # The ildj is scalar, as it does not depend on the values of x and are + # constant across minibatch elements. + return log_total_gamma - 0.5 * log_total_variance diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 3ce7c26213034c7345a20faa803c94a1bfa8d579..85ad23e4133ef09051cdc8b45e489caeea90fbb3 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -21,6 +21,9 @@ from __future__ import print_function import itertools from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector @@ -29,6 +32,91 @@ __all__ = [ ] +def _use_static_shape(input_tensor, ndims): + return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) + + +def _maybe_get_event_ndims_statically(event_ndims): + static_event_ndims = (event_ndims if isinstance(event_ndims, int) + else tensor_util.constant_value(event_ndims)) + if static_event_ndims is not None: + return static_event_ndims + + return event_ndims + + +def _compute_min_event_ndims(bijector_list, compute_forward=True): + """Computes the min_event_ndims associated with the give list of bijectors. + + Given a list `bijector_list` of bijectors, compute the min_event_ndims that is + associated with the composition of bijectors in that list. + + min_event_ndims is the # of right most dimensions for which the bijector has + done necessary computation on (i.e. the non-broadcastable part of the + computation). + + We can derive the min_event_ndims for a chain of bijectors as follows: + + In the case where there are no rank changing bijectors, this will simply be + `max(b.forward_min_event_ndims for b in bijector_list)`. This is because the + bijector with the most forward_min_event_ndims requires the most dimensions, + and hence the chain also requires operating on those dimensions. + + However in the case of rank changing, more care is needed in determining the + exact amount of dimensions. Padding dimensions causes subsequent bijectors to + operate on the padded dimensions, and Removing dimensions causes bijectors to + operate more left. + + Args: + bijector_list: List of bijectors to be composed by chain. + compute_forward: Boolean. If True, computes the min_event_ndims associated + with a forward call to Chain, and otherwise computes the min_event_ndims + associated with an inverse call to Chain. The latter is the same as the + min_event_ndims associated with a forward call to Invert(Chain(....)). + + Returns: + min_event_ndims + """ + min_event_ndims = 0 + # This is a mouthful, but what this encapsulates is that if not for rank + # changing bijectors, we'd only need to compute the largest of the min + # required ndims. Hence "max_min". Due to rank changing bijectors, we need to + # account for synthetic rank growth / synthetic rank decrease from a rank + # changing bijector. + rank_changed_adjusted_max_min_event_ndims = 0 + + if compute_forward: + bijector_list = reversed(bijector_list) + + for b in bijector_list: + if compute_forward: + current_min_event_ndims = b.forward_min_event_ndims + current_inverse_min_event_ndims = b.inverse_min_event_ndims + else: + current_min_event_ndims = b.inverse_min_event_ndims + current_inverse_min_event_ndims = b.forward_min_event_ndims + + # New dimensions were touched. + if rank_changed_adjusted_max_min_event_ndims < current_min_event_ndims: + min_event_ndims += ( + current_min_event_ndims - rank_changed_adjusted_max_min_event_ndims) + rank_changed_adjusted_max_min_event_ndims = max( + current_min_event_ndims, rank_changed_adjusted_max_min_event_ndims) + + # If the number of dimensions has increased via forward, then + # inverse_min_event_ndims > forward_min_event_ndims, and hence the + # dimensions we computed on, have moved left (so we have operated + # on additional dimensions). + # Conversely, if the number of dimensions has decreased via forward, + # then we have inverse_min_event_ndims < forward_min_event_ndims, + # and so we will have operated on fewer right most dimensions. + + number_of_changed_dimensions = ( + current_min_event_ndims - current_inverse_min_event_ndims) + rank_changed_adjusted_max_min_event_ndims -= number_of_changed_dimensions + return min_event_ndims + + class Chain(bijector.Bijector): """Bijector which applies a sequence of bijectors. @@ -93,21 +181,24 @@ class Chain(bijector.Bijector): raise ValueError("incompatible dtypes: %s" % dtype) elif len(dtype) == 2: dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims elif len(dtype) == 1: dtype = dtype[0] - event_ndims = bijectors[0].event_ndims else: dtype = None - event_ndims = None + + inverse_min_event_ndims = _compute_min_event_ndims( + bijectors, compute_forward=False) + forward_min_event_ndims = _compute_min_event_ndims( + bijectors, compute_forward=True) super(Chain, self).__init__( graph_parents=list(itertools.chain.from_iterable( b.graph_parents for b in bijectors)), + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), validate_args=validate_args, dtype=dtype, - event_ndims=event_ndims, name=name or ("identity" if not bijectors else "_of_".join(["chain"] + [b.name for b in bijectors]))) @@ -147,10 +238,31 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") + ildj = constant_op.constant( + 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + + if not self.bijectors: + return ildj + + event_ndims = _maybe_get_event_ndims_statically( + self.inverse_min_event_ndims) + + if _use_static_shape(y, event_ndims): + event_shape = y.shape[y.shape.ndims - event_ndims:] + else: + event_shape = array_ops.shape(y)[array_ops.rank(y) - event_ndims:] + for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + ildj += b.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **kwargs.get(b.name, {})) + + if _use_static_shape(y, event_ndims): + event_shape = b.inverse_event_shape(event_shape) + event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + else: + event_shape = b.inverse_event_shape_tensor(event_shape) + event_ndims = _maybe_get_event_ndims_statically( + array_ops.rank(event_shape)) y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -160,9 +272,34 @@ class Chain(bijector.Bijector): return x def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") + x = ops.convert_to_tensor(x, name="x") + + fldj = constant_op.constant( + 0., dtype=x.dtype, name="inverse_log_det_jacobian") + + if not self.bijectors: + return fldj + + event_ndims = _maybe_get_event_ndims_statically( + self.forward_min_event_ndims) + + if _use_static_shape(x, event_ndims): + event_shape = x.shape[x.shape.ndims - event_ndims:] + else: + event_shape = array_ops.shape(x)[array_ops.rank(x) - event_ndims:] + for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + fldj += b.forward_log_det_jacobian( + x, event_ndims=event_ndims, **kwargs.get(b.name, {})) + if _use_static_shape(x, event_ndims): + event_shape = b.forward_event_shape(event_shape) + event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + else: + event_shape = b.forward_event_shape_tensor(event_shape) + event_ndims = _maybe_get_event_ndims_statically( + array_ops.rank(event_shape)) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj + diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..caae2adcfac7643cdc8f76dd1cccddd516105410 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -20,8 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -39,8 +37,6 @@ __all__ = [ class CholeskyOuterProduct(bijector.Bijector): """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - Note: the upper-triangular part of X is ignored (whether or not its zero). The surjectivity of g as a map from the set of n x n positive-diagonal @@ -61,49 +57,34 @@ class CholeskyOuterProduct(bijector.Bijector): that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - Examples: + #### Examples ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + bijector.CholeskyOuterProduct().forward(x=[[1., 0], [2, 1]]) # Result: [[1., 2], [2, 5]], i.e., x @ x.T - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + bijector.CholeskyOuterProduct().inverse(y=[[1., 2], [2, 5]]) # Result: [[1., 0], [2, 1]], i.e., cholesky(y). ``` """ - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): + def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. """ self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=2, validate_args=validate_args, name=name) def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) if self.validate_args: is_matrix = check_ops.assert_rank_at_least(x, 2) shape = array_ops.shape(x) @@ -114,11 +95,7 @@ class CholeskyOuterProduct(bijector.Bijector): return math_ops.matmul(x, x, adjoint_b=True) def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) + return linalg_ops.cholesky(y) def _forward_log_det_jacobian(self, x): # Let Y be a symmetric, positive definite matrix and write: @@ -161,13 +138,6 @@ class CholeskyOuterProduct(bijector.Bijector): # Since there is a 2 X[j,j] term for every lower-triangular element of X we # conclude: # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - diag = array_ops.matrix_diag_part(x) # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index ccb1f029277bc07011df7be047a075274f2b3a27..e9e994f839ab2fe0a0f52f5f404fb2a0c8f9cd94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -44,12 +44,16 @@ class ConditionalBijector(bijector.Bijector): "**condition_kwargs": "Named arguments forwarded to subclass implementation."}) def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + self, y, event_ndims, name="inverse_log_det_jacobian", + **condition_kwargs): + return self._call_inverse_log_det_jacobian( + y, event_ndims, name, **condition_kwargs) @distribution_util.AppendDocstring(kwargs_dict={ "**condition_kwargs": "Named arguments forwarded to subclass implementation."}) def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): - return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) + self, x, event_ndims, name="forward_log_det_jacobian", + **condition_kwargs): + return self._call_forward_log_det_jacobian( + x, event_ndims, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..9fc1bbf052b419d07a9db149b990c2b80190d72b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -33,8 +33,8 @@ class Exp(power_transform.PowerTransform): ```python # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) + # batch ndim 2. + exp = Exp() x = [[[1., 2], [3, 4]], [[5, 6], @@ -48,19 +48,17 @@ class Exp(power_transform.PowerTransform): """ def __init__(self, - event_ndims=0, validate_args=False, name="exp"): """Instantiates the `Exp` bijector. Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. """ + # forward_min_event_ndims = 0. + # No forward_min_event_ndims specified as this is done in PowerTransform. super(Exp, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index 67f39785563255be0fe154aca3cbcf01c6a01e73..e656a258e56e71898ecb719dd2af876f158cf799 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -48,7 +48,6 @@ class Gumbel(bijector.Bijector): def __init__(self, loc=0., scale=1., - event_ndims=0, validate_args=False, name="gumbel"): """Instantiates the `Gumbel` bijector. @@ -60,8 +59,6 @@ class Gumbel(bijector.Bijector): scale: Positive Float-like `Tensor` that is the same dtype and is broadcastable with `loc`. This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -80,7 +77,9 @@ class Gumbel(bijector.Bijector): ], self._scale) super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + validate_args=validate_args, + forward_min_event_ndims=0, + name=name) @property def loc(self): @@ -102,15 +101,11 @@ class Gumbel(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + return math_ops.log(self.scale / (-math_ops.log(y) * y)) def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + return -z - math_ops.exp(-z) - math_ops.log(self.scale) def _maybe_assert_valid_y(self, y): if not self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..2bde956d1345129285acae4684256c5ac828b9a1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -40,7 +40,7 @@ class Inline(bijector.Bijector): name="exp") ``` - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + The above example is equivalent to the `Bijector` `Exp()`. """ def __init__(self, @@ -54,6 +54,8 @@ class Inline(bijector.Bijector): inverse_event_shape_tensor_fn=None, is_constant_jacobian=False, validate_args=False, + forward_min_event_ndims=None, + inverse_min_event_ndims=None, name="inline"): """Creates a `Bijector` from callables. @@ -76,10 +78,15 @@ class Inline(bijector.Bijector): constant for all input arguments. validate_args: Python `bool` indicating whether arguments should be checked for correctness. + forward_min_event_ndims: Python `int` indicating the minimal + dimensionality this bijector acts on. + inverse_min_event_ndims: Python `int` indicating the minimal + dimensionality this bijector acts on. name: Python `str`, name given to ops managed by this object. """ super(Inline, self).__init__( - event_ndims=0, + forward_min_event_ndims=forward_min_event_ndims, + inverse_min_event_ndims=inverse_min_event_ndims, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -134,8 +141,8 @@ class Inline(bijector.Bijector): "inverse_log_det_jacobian_fn is not a callable function.") return self._inverse_log_det_jacobian_fn(y, **kwargs) - def _forward_log_det_jacobian(self, y, **kwargs): + def _forward_log_det_jacobian(self, x, **kwargs): if not callable(self._forward_log_det_jacobian_fn): raise NotImplementedError( "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) + return self._forward_log_det_jacobian_fn(x, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..1904239a0e7009c35cc4f3c8876fd749463a2b83 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -66,8 +66,9 @@ class Invert(bijector_lib.Bijector): self._bijector = bijector super(Invert, self).__init__( - event_ndims=bijector.event_ndims, graph_parents=bijector.graph_parents, + forward_min_event_ndims=bijector.inverse_min_event_ndims, + inverse_min_event_ndims=bijector.forward_min_event_ndims, is_constant_jacobian=bijector.is_constant_jacobian, validate_args=validate_args, dtype=bijector.dtype, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..97000c17262d3efdef10274711364c2bc2083bd4 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -0,0 +1,132 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kumaraswamy bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Kumaraswamy", +] + + +class Kumaraswamy(bijector.Bijector): + """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`. + + This bijector maps inputs from `[0, 1]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the [Kumaraswamy distribution]( + https://en.wikipedia.org/wiki/Kumaraswamy_distribution): + + ```none + Y ~ Kumaraswamy(a, b) + pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1) + ``` + """ + + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + name="kumaraswamy"): + """Instantiates the `Kumaraswamy` bijector. + + Args: + concentration1: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is + `concentration1`. + concentration0: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is + `concentration0`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + with self._name_scope("init", values=[concentration1, concentration0]): + concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args=validate_args) + concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args=validate_args) + + self._concentration1 = concentration1 + self._concentration0 = concentration0 + super(Kumaraswamy, self).__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + @property + def concentration1(self): + """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration1 + + @property + def concentration0(self): + """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration0 + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.exp( + math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0)) + / self.concentration1) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.exp(math_ops.log1p( + -(1 - y**self.concentration1)**self.concentration0)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid(y) + return ( + math_ops.log(self.concentration1) + math_ops.log(self.concentration0) + + (self.concentration1 - 1) * math_ops.log(y) + + (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1)) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid(self, x): + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + x, + message="sample must be non-negative"), + check_ops.assert_less_equal( + x, array_ops.ones([], self.concentration0.dtype), + message="sample must be no larger than `1`."), + ], x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 5251dbcb5748f75688aa43ce6e4e9dbd76be78bb..ef56cf6ddda4dca2b1575e844b2584689e531b81 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -45,14 +45,15 @@ __all__ = [ class MaskedAutoregressiveFlow(bijector_lib.Bijector): """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, + The affine autoregressive flow [(Papamakarios et al., 2016)][3] provides a + relatively simple framework for user-specified (deep) architectures to learn + a distribution over vector-valued events. Regarding terminology, "Autoregressive models decompose the joint density as a product of conditionals, and model each conditional in turn. Normalizing flows transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] + by an invertible transformation with tractable Jacobian." + [(Papamakarios et al., 2016)][3] In other words, the "autoregressive property" is equivalent to the decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided @@ -60,7 +61,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): this property by zeroing out weights in its `masked_dense` layers. In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression" is implemented using a `tf.while_loop` and a deep neural network (DNN) with masked weights such that the autoregressive property is automatically met in the `inverse`. @@ -75,26 +76,26 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): Given a `shift_and_log_scale_fn`, the forward and inverse transformations are (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. + must compute each `shift` (aka `loc` or "mu" in [Germain et al. (2015)][1]) + and `log(scale)` (aka "alpha" in [Germain et al. (2015)][1]) such that each + are broadcastable with the arguments to `forward` and `inverse`, i.e., such + that the calculations in `forward`, `inverse` [below] are possible. For convenience, `masked_autoregressive_default_template` is offered as a possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. + architecture [(Germain et al., 2015)][1]. MADE is a feed-forward network that + computes a `shift` and `log(scale)` using `masked_dense` layers in a deep + neural network. Weights are masked to ensure the autoregressive property. It + is possible that this architecture is suboptimal for your task. To build + alternative networks, either change the arguments to + `masked_autoregressive_default_template`, use the `masked_dense` function to + roll-out your own, or use some other architecture, e.g., using `tf.layers`. Warning: no attempt is made to validate that the `shift_and_log_scale_fn` enforces the "autoregressive property". Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, + semantics, the forward transformation is ```python def forward(x): @@ -106,7 +107,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): return y ``` - and the inverse transformation is, + and the inverse transformation is ```python def inverse(y): @@ -121,7 +122,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this also proves the transform is bijective.) - #### Example Use + #### Examples ```python tfd = tf.contrib.distributions @@ -142,7 +143,8 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): maf.log_prob(x) # Almost free; uses Bijector caching. maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - # [1] also describes an "Inverse Autoregressive Flow", e.g., + # [Papamakarios et al. (2016)][3] also describe an Inverse Autoregressive + # Flow [(Kingma et al., 2016)][2]: iaf = tfd.TransformedDistribution( distribution=tfd.Normal(loc=0., scale=1.), bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( @@ -168,14 +170,20 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): event_shape=[dims]) ``` - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + #### References - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 + [2]: Diederik P. Kingma, Tim Salimans, Rafal Jozefowicz, Xi Chen, Ilya + Sutskever, and Max Welling. Improving Variational Inference with Inverse + Autoregressive Flow. In _Neural Information Processing Systems_, 2016. + https://arxiv.org/abs/1606.04934 + + [3]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, @@ -212,6 +220,7 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector): self._shift_and_log_scale_fn = shift_and_log_scale_fn self._unroll_loop = unroll_loop super(MaskedAutoregressiveFlow, self).__init__( + forward_min_event_ndims=1, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -329,11 +338,7 @@ def masked_dense(inputs, **kwargs): """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 + See [Germain et al. (2015)][1] for detailed explanation. Arguments: inputs: Tensor input. @@ -358,6 +363,12 @@ def masked_dense(inputs, Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ # TODO(b/67594795): Better support of dynamic shape. input_depth = inputs.shape.with_rank_at_least(1)[-1].value @@ -398,23 +409,24 @@ def masked_autoregressive_default_template( name=None, *args, **kwargs): - """Build the MADE Model [1]. + """Build the Masked Autoregressive Density Estimator (Germain et al., 2015). This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. + created once. It takes the input and returns the `loc` ("mu" in [Germain et + al. (2015)][1]) and `log_scale` ("alpha" in [Germain et al. (2015)][1]) from + the MADE network. Warning: This function uses `masked_dense` to create randomly initialized `tf.Variables`. It is presumed that these will be fit, just as you would any other neural architecture which uses `tf.layers.dense`. - #### About Hidden Layers: + #### About Hidden Layers Each element of `hidden_layers` should be greater than the `input_depth` (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the neural network). This is necessary to ensure the autoregressivity property. - #### About Clipping: + #### About Clipping This function also optionally clips the `log_scale` (but possibly not its gradient). This is useful because if `log_scale` is too small/large it might @@ -427,11 +439,7 @@ def masked_autoregressive_default_template( `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual `grad[clip(x)] exp(clip(x))`. - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: + Args: hidden_layers: Python `list`-like of non-negative integer, scalars indicating the number of units in each hidden layer. Default: `[512, 512]. shift_only: Python `bool` indicating if only the `shift` term shall be @@ -450,12 +458,20 @@ def masked_autoregressive_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms (the "mu" in + [Germain et al. (2015)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in + [Germain et al. (2015)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: Mathieu Germain, Karol Gregor, Iain Murray, and Hugo Larochelle. MADE: + Masked Autoencoder for Distribution Estimation. In _International + Conference on Machine Learning_, 2015. https://arxiv.org/abs/1502.03509 """ with ops.name_scope(name, "masked_autoregressive_default_template", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 8654cc39d0c41ec4f1b85cd5fc4366ceaf4b224d..4978167803fc38b112c95922519c8c296cee2561 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -114,6 +114,7 @@ class Permute(bijector_lib.Bijector): ], permutation) self._permutation = permutation super(Permute, self).__init__( + forward_min_event_ndims=1, is_constant_jacobian=True, validate_args=validate_args, name=name or "permute") @@ -132,7 +133,10 @@ class Permute(bijector_lib.Bijector): axis=-1) def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) + # is_constant_jacobian = True for this bijector, hence the + # `log_det_jacobian` need only be specified for a single input, as this will + # be tiled to match `event_ndims`. + return constant_op.constant(0., dtype=y.dtype.base_dtype) def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) + return constant_op.constant(0., dtype=x.dtype.base_dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index c37db61720d10949f294ff7b2e9778ba6efa57f0..71f123f2a998458edaa9c8da07ea2932f62625ca 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -43,7 +43,6 @@ class PowerTransform(bijector.Bijector): def __init__(self, power=0., - event_ndims=0, validate_args=False, name="power_transform"): """Instantiates the `PowerTransform` bijector. @@ -51,8 +50,6 @@ class PowerTransform(bijector.Bijector): Args: power: Python `float` scalar indicating the transform power, i.e., `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -70,7 +67,7 @@ class PowerTransform(bijector.Bijector): raise ValueError("`power` must be a non-negative TF constant.") self._power = power super(PowerTransform, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -97,18 +94,13 @@ class PowerTransform(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) + return (self.power - 1.) * math_ops.log(y) def _forward_log_det_jacobian(self, x): x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) + return x + return (1. / self.power - 1.) * math_ops.log1p(x * self.power) def _maybe_assert_valid_x(self, x): if not self.validate_args or self.power == 0.: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 2840f52e742eac5e9e37a576bf7f6d6f05a07a35..f09ab21bce100e9dafb77eff1f3999ce4b71c681 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -38,7 +38,7 @@ class RealNVP(bijector_lib.Bijector): """RealNVP "affine coupling layer" for vector-valued events. Real NVP models a normalizing flow on a `D`-dimensional distribution via a - single `D-d`-dimensional conditional distribution [1]: + single `D-d`-dimensional conditional distribution [(Dinh et al., 2017)][1]: `y[d:D] = y[d:D] * math_ops.exp(log_scale_fn(y[d:D])) + shift_fn(y[d:D])` `y[0:d] = x[0:d]` @@ -51,31 +51,34 @@ class RealNVP(bijector_lib.Bijector): Masking is currently only supported for base distributions with `event_ndims=1`. For more sophisticated masking schemes like checkerboard or - channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired - masked units into the first `d` units. For base distributions with - `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape. - - Recall that the MAF bijector [2] implements a normalizing flow via an - autoregressive transformation. MAF and IAF have opposite computational - tradeoffs - MAF can train all units in parallel but must sample units - sequentially, while IAF must train units sequentially but can sample in - parallel. In contrast, Real NVP can compute both forward and inverse - computations in parallel. However, the lack of an autoregressive + channel-wise masking [(Papamakarios et al., 2016)[4], use the `tfb.Permute` + bijector to re-order desired masked units into the first `d` units. For base + distributions with `event_ndims > 1`, use the `tfb.Reshape` bijector to + flatten the event shape. + + Recall that the MAF bijector [(Papamakarios et al., 2016)][4] implements a + normalizing flow via an autoregressive transformation. MAF and IAF have + opposite computational tradeoffs - MAF can train all units in parallel but + must sample units sequentially, while IAF must train units sequentially but + can sample in parallel. In contrast, Real NVP can compute both forward and + inverse computations in parallel. However, the lack of an autoregressive transformations makes it less expressive on a per-bijector basis. A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or - "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable - with the arguments to `forward` and `inverse`, i.e., such that the - calculations in `forward`, `inverse` [below] are possible. For convenience, + "mu" in [Papamakarios et al. (2016)][4]) and `log(scale)` (aka "alpha" in + [Papamakarios et al. (2016)][4]) such that each are broadcastable with the + arguments to `forward` and `inverse`, i.e., such that the calculations in + `forward`, `inverse` [below] are possible. For convenience, `real_nvp_default_nvp` is offered as a possible `shift_and_log_scale_fn` function. - NICE [3] is a special case of the Real NVP bijector which discards the scale - transformation, resulting in a constant-time inverse-log-determinant-Jacobian. - To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should - return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in - the `RealNVP` constructor. Calling `real_nvp_default_template` with - `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`. + NICE [(Dinh et al., 2014)][2] is a special case of the Real NVP bijector + which discards the scale transformation, resulting in a constant-time + inverse-log-determinant-Jacobian. To use a NICE bijector instead of Real + NVP, `shift_and_log_scale_fn` should return `(shift, None)`, and + `is_constant_jacobian` should be set to `True` in the `RealNVP` constructor. + Calling `real_nvp_default_template` with `shift_only=True` returns one such + NICE-compatible `shift_and_log_scale_fn`. Caching: the scalar input depth `D` of the base distribution is not known at construction time. The first call to any of `forward(x)`, `inverse(x)`, @@ -103,23 +106,24 @@ class RealNVP(bijector_lib.Bijector): nvp.log_prob(0.) ``` - For more examples, see [4]. + For more examples, see [Jang (2018)][3]. - [1]: "Density Estimation using Real NVP." - Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. - https://arxiv.org/abs/1605.08803 + #### References - [2]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 + [1]: Laurent Dinh, Jascha Sohl-Dickstein, and Samy Bengio. Density Estimation + using Real NVP. In _International Conference on Learning + Representations_, 2017. https://arxiv.org/abs/1605.08803 - [3]: "NICE: Non-linear Independent Components Estimation." - Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015. - https://arxiv.org/abs/1410.8516 + [2]: Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear + Independent Components Estimation. _arXiv preprint arXiv:1410.8516_, + 2014. https://arxiv.org/abs/1410.8516 - [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows." - Eric Jang. Blog post. January 2018. - http://blog.evjang.com/2018/01/nf2.html + [3]: Eric Jang. Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows. + _Technical Report_, 2018. http://blog.evjang.com/2018/01/nf2.html + + [4]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ def __init__(self, @@ -162,7 +166,7 @@ class RealNVP(bijector_lib.Bijector): self._input_depth = None self._shift_and_log_scale_fn = shift_and_log_scale_fn super(RealNVP, self).__init__( - event_ndims=1, + forward_min_event_ndims=1, is_constant_jacobian=is_constant_jacobian, validate_args=validate_args, name=name) @@ -220,7 +224,7 @@ class RealNVP(bijector_lib.Bijector): _, log_scale = self._shift_and_log_scale_fn( x0, self._input_depth - self._num_masked) if log_scale is None: - return constant_op.constant(0., dtype=x.dtype, name="ildj") + return constant_op.constant(0., dtype=x.dtype, name="fldj") return math_ops.reduce_sum(log_scale, axis=-1) @@ -250,12 +254,20 @@ def real_nvp_default_template( **kwargs: `tf.layers.dense` keyword arguments. Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + shift: `Float`-like `Tensor` of shift terms ("mu" in + [Papamakarios et al. (2016)][1]). + log_scale: `Float`-like `Tensor` of log(scale) terms ("alpha" in + [Papamakarios et al. (2016)][1]). Raises: NotImplementedError: if rightmost dimension of `inputs` is unknown prior to graph execution. + + #### References + + [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked + Autoregressive Flow for Density Estimation. In _Neural Information + Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ with ops.name_scope(name, "real_nvp_default_template"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 55eca063126797d577653f0d6bcdfddf8192bdb5..f21b982ba664b312c716827c7925767a0b5a037a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -128,15 +128,17 @@ class Reshape(bijector_lib.Bijector): self._event_shape_in = event_shape_in self._event_shape_out = event_shape_out - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") + super(Reshape, self).__init__( + forward_min_event_ndims=0, + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") def _maybe_check_valid_shape(self, shape, validate_args): """Check that a shape Tensor is int-type and otherwise sane.""" if not shape.dtype.is_integer: raise TypeError("{} dtype ({}) should be `int`-like.".format( - shape.op.name, shape.dtype.name)) + shape, shape.dtype.name)) assertions = [] @@ -144,10 +146,10 @@ class Reshape(bijector_lib.Bijector): ndims_ = tensor_util.constant_value(ndims) if ndims_ is not None and ndims_ > 1: raise ValueError("`{}` rank ({}) should be <= 1.".format( - shape.op.name, ndims_)) + shape, ndims_)) elif validate_args: assertions.append(check_ops.assert_less_equal( - ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + ndims, 1, message="`{}` rank should be <= 1.".format(shape))) shape_ = tensor_util.constant_value_as_shape(shape) if shape_.is_fully_defined(): @@ -155,12 +157,12 @@ class Reshape(bijector_lib.Bijector): if sum(es == -1) > 1: raise ValueError( "`{}` must have at most one `-1` (given {})" - .format(shape.op.name, es)) + .format(shape, es)) if np.any(es < -1): raise ValueError( "`{}` elements must be either positive integers or `-1`" "(given {})." - .format(shape.op.name, es)) + .format(shape, es)) elif validate_args: assertions.extend([ check_ops.assert_less_equal( @@ -168,11 +170,11 @@ class Reshape(bijector_lib.Bijector): math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), 1, message="`{}` elements must have at most one `-1`." - .format(shape.op.name)), + .format(shape)), check_ops.assert_greater_equal( shape, -1, message="`{}` elements must be either positive integers or `-1`." - .format(shape.op.name)), + .format(shape)), ]) return assertions diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index a640dfe7dfbcce96261589c7fc49107deaefdd54..5df8c886315ff75cdc884e3b9b4665fb64bb109d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -33,7 +33,9 @@ class Sigmoid(bijector.Bijector): def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) def _forward(self, x): return math_ops.sigmoid(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py deleted file mode 100644 index 223bc9d042c69be05b0e578835a31ed6e83c0c97..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SigmoidCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered - - -__all__ = [ - "SigmoidCentered", -] - - -class SigmoidCentered(softmax_centered.SoftmaxCentered): - """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). - - Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. - - See `bijector.SoftmaxCentered` for more details. - """ - - def __init__(self, validate_args=False, name="sigmoid_centered"): - super(SigmoidCentered, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 3a75e4ae9495793901b0da91a5aa3982aab35852..2a32e8abcde940b0056b0faf2955ec1b3bd71803 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -91,7 +91,6 @@ class SinhArcsinh(bijector.Bijector): def __init__(self, skewness=None, tailweight=None, - event_ndims=0, validate_args=False, name="SinhArcsinh"): """Instantiates the `SinhArcsinh` bijector. @@ -101,8 +100,6 @@ class SinhArcsinh(bijector.Bijector): of type `float32`. tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -125,7 +122,9 @@ class SinhArcsinh(bijector.Bijector): message="Argument tailweight was not positive") ], self._tailweight) super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) @property def skewness(self): @@ -149,31 +148,29 @@ class SinhArcsinh(bijector.Bijector): # dx/dy # = cosh(arcsinh(y) / tailweight - skewness) # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + return ( math_ops.log(math_ops.cosh( math_ops.asinh(y) / self.tailweight - self.skewness) # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) + - math_ops.log(self.tailweight)) def _forward_log_det_jacobian(self, x): # y = sinh((arcsinh(x) + skewness) * tailweight) # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), # dy/dx # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + return ( math_ops.log(math_ops.cosh( (math_ops.asinh(x) + self.skewness) * self.tailweight) # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) + + math_ops.log(self.tailweight)) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index a9dcce6c526600f3b26c6bceb730417000917ce7..f52b91550edff7390d8094a4508d862674e85d59 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,13 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -47,17 +42,14 @@ class SoftmaxCentered(bijector.Bijector): e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last coordinate. - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - Example Use: ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + bijector.SoftmaxCentered().forward(tf.log([2, 3, 4])) # Result: [0.2, 0.3, 0.4, 0.1] # Extra result: 0.1 - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + bijector.SoftmaxCentered().inverse([0.2, 0.3, 0.4, 0.1]) # Result: tf.log([2, 3, 4]) # Extra coordinate removed. ``` @@ -69,87 +61,50 @@ class SoftmaxCentered(bijector.Bijector): """ def __init__(self, - event_ndims=0, validate_args=False, name="softmax_centered"): self._graph_parents = [] self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=1, validate_args=validate_args, name=name) def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: + if input_shape.ndims is None or input_shape[-1] is None: return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + return tensor_shape.TensorShape([input_shape[-1] + 1]) def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 + return (input_shape[-1] + 1)[..., array_ops.newaxis] def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: + if output_shape.ndims is None or output_shape[-1] is None: return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) + if output_shape[-1] <= 1: + raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1]) + return tensor_shape.TensorShape([output_shape[-1] - 1]) def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] if self.validate_args: # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) + is_greater_one = check_ops.assert_greater( + output_shape[-1], 1, message="Need last dimension greater than 1.") + output_shape = control_flow_ops.with_dependencies( + [is_greater_one], output_shape) + return (output_shape[-1] - 1)[..., array_ops.newaxis] def _forward(self, x): # Pad the last dim with a zeros vector. We need this because it lets us # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - y = distribution_util.pad(y, axis=-1, back=True) + y = distribution_util.pad(x, axis=-1, back=True) # Set shape hints. if x.shape.ndims is not None: - shape = x.shape.as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) + shape = x.shape[:-1].concatenate(x.shape[-1] + 1) y.shape.assert_is_compatible_with(shape) y.set_shape(shape) - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). return nn_ops.softmax(y) def _inverse(self, y): @@ -161,42 +116,17 @@ class SoftmaxCentered(bijector.Bijector): # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) # = log(exp(x[i])/normalization) - log(y[end]) # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.shape.as_list(), dtype=np.int32) - if y.shape.is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = distribution_util.prefer_static_rank(y) # Do this first to make sure CSE catches that it'll happen again in # _inverse_log_det_jacobian. x = math_ops.log(y) - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + log_normalization = (-x[..., -1])[..., array_ops.newaxis] + x = x[..., :-1] + log_normalization # Set shape hints. if y.shape.ndims is not None: - shape = y.shape.as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) + shape = y.shape[:-1].concatenate(y.shape[-1] - 1) x.shape.assert_is_compatible_with(shape) x.set_shape(shape) @@ -222,19 +152,14 @@ class SoftmaxCentered(bijector.Bijector): return -math_ops.reduce_sum(math_ops.log(y), axis=-1) def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + return array_ops.squeeze( + (-log_normalization + math_ops.reduce_sum( + x - log_normalization, axis=-1, keepdims=True)), axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 81957fcf78922fa15fd20a25d144071f431161ae..96a938c803418ff818f9c531754b47ba1eb8667a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -62,7 +62,7 @@ class Softplus(bijector.Bijector): ```python # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(event_ndims=2) + softplus = Softplus() x = [[[1., 2], [3, 4]], [[5, 6], @@ -81,7 +81,6 @@ class Softplus(bijector.Bijector): "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) def __init__(self, - event_ndims=0, hinge_softness=None, validate_args=False, name="softplus"): @@ -101,7 +100,7 @@ class Softplus(bijector.Bijector): [nonzero_check], self.hinge_softness) super(Softplus, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -130,14 +129,12 @@ class Softplus(bijector.Bijector): # 1 - exp{-Y} approx Y. if self.hinge_softness is not None: y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) + return -math_ops.log(-math_ops.expm1(-y)) def _forward_log_det_jacobian(self, x): if self.hinge_softness is not None: x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) + return -nn_ops.softplus(-x) @property def hinge_softness(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a658c171b8313358754228aabbfa4bf93fd84d --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py @@ -0,0 +1,86 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Softsign bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "Softsign", +] + + +class Softsign(bijector.Bijector): + """Bijector which computes `Y = g(X) = X / (1 + |X|)`. + + The softsign `Bijector` has the following two useful properties: + + * The domain is all real numbers + * `softsign(x) approx sgn(x)`, for large `|x|`. + + #### Examples + + ```python + # Create the Y = softsign(X) transform. + softsign = Softsign() + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + x / (1 + abs(x)) == softsign.forward(x) + x / (1 - abs(x)) == softsign.inverse(x) + ``` + """ + + def __init__(self, validate_args=False, name="softsign"): + super(Softsign, self).__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return x / (1. + math_ops.abs(x)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return y / (1. - math_ops.abs(y)) + + def _forward_log_det_jacobian(self, x): + return -2. * math_ops.log1p(math_ops.abs(x)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + return -2. * math_ops.log1p(-math_ops.abs(y)) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = [ + check_ops.assert_greater( + y, math_ops.cast(-1., dtype=y.dtype.base_dtype), + message="Inverse transformation input must be greater than -1."), + check_ops.assert_less( + y, math_ops.cast(1., dtype=y.dtype.base_dtype), + message="Inverse transformation input must be less than 1.") + ] + + return control_flow_ops.with_dependencies(is_valid, y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py new file mode 100644 index 0000000000000000000000000000000000000000..2ccfdc95970e387e708603e2614ad29fb6a18db3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -0,0 +1,84 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Square bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "Square", +] + + +class Square(bijector.Bijector): + """Compute `g(X) = X^2`; X is a positive real number. + + g is a bijection between the non-negative real numbers (R_+) and the + non-negative real numbers. + + #### Examples + + ```python + bijector.Square().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [4, 1]], i.e., x^2 + + bijector.Square().inverse(y=[[1., 4], [9, 1]]) + # Result: [[1., 2], [3, 1]], i.e., sqrt(y). + ``` + + """ + + def __init__(self, validate_args=False, name="square"): + """Instantiates the `Square` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._name = name + super(Square, self).__init__( + forward_min_event_ndims=0, + validate_args=validate_args, + name=name) + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.square(x) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.sqrt(y) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid(x) + return np.log(2.) + math_ops.log(x) + + def _maybe_assert_valid(self, t): + if not self.validate_args: + return t + is_valid = check_ops.assert_non_negative( + t, message="All elements must be non-negative.") + return control_flow_ops.with_dependencies([is_valid], t) + diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index 00520bcda85e9527767e6342bf75f10667c264a8..39129cd22cdbf9ca1b4edd7cb5c3571a33837a29 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -50,7 +50,6 @@ class Weibull(bijector.Bijector): def __init__(self, scale=1., concentration=1., - event_ndims=0, validate_args=False, name="weibull"): """Instantiates the `Weibull` bijector. @@ -62,8 +61,6 @@ class Weibull(bijector.Bijector): concentration: Positive Float-type `Tensor` that is the same dtype and is broadcastable with `scale`. This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be checked for correctness. name: Python `str` name given to ops managed by this object. @@ -89,7 +86,7 @@ class Weibull(bijector.Bijector): ], self._concentration) super(Weibull, self).__init__( - event_ndims=event_ndims, + forward_min_event_ndims=0, validate_args=validate_args, name=name) @@ -113,22 +110,18 @@ class Weibull(bijector.Bijector): def _inverse_log_det_jacobian(self, y): y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( + return ( -math_ops.log1p(-y) + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) + math_ops.log(self.scale / self.concentration)) def _forward_log_det_jacobian(self, x): x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( + return ( -(x / self.scale) ** self.concentration + (self.concentration - 1) * math_ops.log(x) + math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - axis=event_dims) + -self.concentration * math_ops.log(self.scale)) def _maybe_assert_valid_x(self, x): if not self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index bdd5571c966a74e58e4f9f8eed2628f131a1b92e..e610f469e5d5f446b75c734cc39811de30a8cb9a 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -21,6 +21,8 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma @@ -87,7 +89,11 @@ class Chi2(gamma.Gamma): # allow_nan_stats=True # through to the parent class results in unnecessary asserts. with ops.name_scope(name, values=[df]): - self._df = ops.convert_to_tensor(df, name="df") + with ops.control_dependencies([ + check_ops.assert_positive(df), + ] if validate_args else []): + self._df = array_ops.identity(df, name="df") + super(Chi2, self).__init__( concentration=0.5 * self._df, rate=constant_op.constant(0.5, dtype=self._df.dtype), diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea..10b45361358b40a3c8fd725f27ad84ef9b8a37f5 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -105,7 +106,9 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + event_ndims = self._maybe_get_event_ndims_statically() + ildj = self.bijector.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_log_prob_for_one_fiber(y, x, ildj, distribution_kwargs) @@ -128,7 +131,9 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + event_ndims = self._maybe_get_event_ndims_statically() + ildj = self.bijector.inverse_log_det_jacobian( + y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access return self._finish_prob_for_one_fiber(y, x, ildj, distribution_kwargs) @@ -214,3 +219,15 @@ class ConditionalTransformedDistribution( # implies the qth quantile of Y is g(x_q). inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) + + def _maybe_get_event_ndims_statically(self): + if self.event_shape.ndims is not None: + return self.event_shape.ndims + + event_ndims = array_ops.size(self.event_shape_tensor()) + static_event_ndims = tensor_util.constant_value(event_ndims) + + if static_event_ndims is not None: + return static_event_ndims + + return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py index 6b53338c4542c75d3977c075b7750c780080ac48..98edd337fe02ffbf53c6ecd9ebda9424231ea2fe 100644 --- a/tensorflow/contrib/distributions/python/ops/estimator.py +++ b/tensorflow/contrib/distributions/python/ops/estimator.py @@ -75,7 +75,7 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): - """Creates a _RegressionHead instance from an arbitray `Distribution`.""" + """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" def __init__(self, make_distribution_fn, diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5..8d05ad6b8032fb8bada99389959091fb1c28beda 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -190,9 +190,6 @@ class _Gumbel(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return -math_ops.exp(-self._z(x)) diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index cbce005013281ff3c58c94d525d5ce7a865d725a..b1bacb91b03093fa93a7e5f7eb855dc944dafb44 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import kullback_leibler class Independent(distribution_lib.Distribution): @@ -35,7 +36,7 @@ class Independent(distribution_lib.Distribution): This distribution is useful for regarding a collection of independent, non-identical distributions as a single random variable. For example, the - `Indpendent` distribution composed of a collection of `Bernoulli` + `Independent` distribution composed of a collection of `Bernoulli` distributions might define a distribution over an image (where each `Bernoulli` is a distribution over each pixel). @@ -254,3 +255,58 @@ class Independent(distribution_lib.Distribution): else: which_maximum = np.maximum return which_maximum(0, ndims - 1) + + +@kullback_leibler.RegisterKL(Independent, Independent) +def _kl_independent(a, b, name="kl_independent"): + """Batched KL divergence `KL(a || b)` for Independent distributions. + + We can leverage the fact that + ``` + KL(Independent(a) || Independent(b)) = sum(KL(a || b)) + ``` + where the sum is over the `reinterpreted_batch_ndims`. + + Args: + a: Instance of `Independent`. + b: Instance of `Independent`. + name: (optional) name to use for created ops. Default "kl_independent". + + Returns: + Batchwise `KL(a || b)`. + + Raises: + ValueError: If the event space for `a` and `b`, or their underlying + distributions don't match. + """ + p = a.distribution + q = b.distribution + + # The KL between any two (non)-batched distributions is a scalar. + # Given that the KL between two factored distributions is the sum, i.e. + # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(q1 || q2), we compute + # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions. + if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined(): + if a.event_shape == b.event_shape: + if p.event_shape == q.event_shape: + num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims + reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)] + + return math_ops.reduce_sum( + kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) + else: + raise NotImplementedError("KL between Independents with different " + "event shapes not supported.") + else: + raise ValueError("Event shapes do not match.") + else: + with ops.control_dependencies([ + check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()), + check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor()) + ]): + num_reduce_dims = ( + array_ops.shape(a.event_shape_tensor()[0]) - + array_ops.shape(p.event_shape_tensor()[0])) + reduce_dims = math_ops.range(-num_reduce_dims - 1, -1, 1) + return math_ops.reduce_sum( + kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index ee4d86867d48b20e97757bcec57d452085814b80..51ac61dcf640ca89f22c47127bda71316a179ca4 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -192,12 +192,6 @@ class InverseGamma(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - - def _log_cdf(self, x): - return math_ops.log(self._cdf(x)) - def _cdf(self, x): x = self._maybe_assert_valid_sample(x) # Note that igammac returns the upper regularized incomplete gamma diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 74d5d8773cf3e69a52554c87d656fea2835c8354..192dede6ff1d4de8d4be9965c414e7453d7b5d4b 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -20,15 +20,17 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.distributions import beta from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.distributions import uniform from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -42,25 +44,23 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. - Derivation from [1] and Euler's constant [2]. - [1] - - https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers - [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant - + Derivation from [here]( + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers) + and [Euler's constant]( + https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant). Args: x: input float. Returns: z: The analytic continuation of the harmonic number for the input. - """ one = array_ops.ones([], dtype=x.dtype) return math_ops.digamma(x + one) - math_ops.digamma(one) @tf_export("distributions.Kumaraswamy") -class Kumaraswamy(beta.Beta): +class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. The Kumaraswamy distribution is defined over the `(0, 1)` interval using @@ -151,59 +151,32 @@ class Kumaraswamy(beta.Beta): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ + concentration1 = ops.convert_to_tensor( + concentration1, name="concentration1") + concentration0 = ops.convert_to_tensor( + concentration0, name="concentration0") super(Kumaraswamy, self).__init__( - concentration1=concentration1, - concentration0=concentration0, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, + distribution=uniform.Uniform( + low=array_ops.zeros([], dtype=concentration1.dtype), + high=array_ops.ones([], dtype=concentration1.dtype), + allow_nan_stats=allow_nan_stats), + bijector=bijectors.Kumaraswamy( + concentration1=concentration1, concentration0=concentration0, + validate_args=validate_args), + batch_shape=distribution_util.get_broadcast_shape( + concentration1, concentration0), name=name) self._reparameterization_type = distribution.FULLY_REPARAMETERIZED - def _sample_n(self, n, seed=None): - expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 - expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 - shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) - uniform_sample = random_ops.random_uniform( - shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) - - kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( - 1. / expanded_concentration1) - return kumaraswamy_sample - - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _log_cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return math_ops.log1p(-(1 - x**a)**b) + @property + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self.bijector.concentration1 - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return 1 - (1 - x**a)**b - - def _survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return (1 - x**a)**b - - def _log_survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return b * math_ops.log1p(-x**a) - - def _log_unnormalized_prob(self, x): - x = self._maybe_assert_valid_sample(x) - a = self.concentration1 - b = self.concentration0 - return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) - - def _log_normalization(self): - a = self.concentration1 - b = self.concentration0 - return -(math_ops.log(a) + math_ops.log(b)) + @property + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self.bijector.concentration0 def _entropy(self): a = self.concentration1 @@ -213,10 +186,11 @@ class Kumaraswamy(beta.Beta): def _moment(self, n): """Compute the n'th (uncentered) moment.""" + total_concentration = self.concentration1 + self.concentration0 expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 + total_concentration, dtype=self.dtype) * self.concentration1 expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 + total_concentration, dtype=self.dtype) * self.concentration0 beta_arg0 = 1 + n / expanded_concentration1 beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( @@ -246,13 +220,14 @@ class Kumaraswamy(beta.Beta): name="nan") is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration1.dtype), self.concentration1, message="Mode undefined for concentration1 <= 1."), check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration0.dtype), self.concentration0, message="Mode undefined for concentration0 <= 1.") ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 473677f8d91b184e029f345bb05f5c5d63df7a40..68e6bca5a554b29a450911073eb5c4fe55f313c6 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -185,9 +185,6 @@ class Logistic(distribution.Distribution): def _log_prob(self, x): return self._log_unnormalized_prob(x) - self._log_normalization() - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _log_cdf(self, x): return -nn_ops.softplus(-self._z(x)) diff --git a/tensorflow/contrib/distributions/python/ops/moving_stats.py b/tensorflow/contrib/distributions/python/ops/moving_stats.py index 20f85643b9e7db61b4786dffe4115c7d3c00b046..87d40805a3c7a9c2871305af7f7182b7e2923530 100644 --- a/tensorflow/contrib/distributions/python/ops/moving_stats.py +++ b/tensorflow/contrib/distributions/python/ops/moving_stats.py @@ -47,9 +47,7 @@ def assign_moving_mean_variance( Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-1 mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Args: mean_var: `float`-like `Variable` representing the exponentially weighted @@ -72,6 +70,12 @@ def assign_moving_mean_variance( TypeError: if `mean_var` does not have float type `dtype`. TypeError: if `mean_var`, `variance_var`, `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ with ops.name_scope(name, "assign_moving_mean_variance", [variance_var, mean_var, value, decay]): @@ -183,9 +187,7 @@ def moving_mean_variance(value, decay, collections=None, name=None): Note: `mean_var` is updated *after* `variance_var`, i.e., `variance_var` uses the lag-`1` mean. - For derivation justification, see equation 143 of: - T. Finch, Feb 2009. "Incremental calculation of weighted mean and variance". - http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf + For derivation justification, see [Finch (2009; Eq. 143)][1]. Unlike `assign_moving_mean_variance`, this function handles variable creation. @@ -208,6 +210,12 @@ def moving_mean_variance(value, decay, collections=None, name=None): Raises: TypeError: if `value_var` does not have float type `dtype`. TypeError: if `value`, `decay` have different `base_dtype`. + + #### References + + [1]: Tony Finch. Incremental calculation of weighted mean and variance. + _Technical Report_, 2009. + http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf """ if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index b76cebf79fad09ebec68f2459c6fe80794ea81c0..e3e40b2e9ca232b9970768f21fb95887fdf0df2d 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -52,7 +52,7 @@ class OneHotCategorical(distribution.Distribution): #### Examples - Creates a 3-class distiribution, with the 2nd class, the most likely to be + Creates a 3-class distribution, with the 2nd class, the most likely to be drawn from. ```python @@ -60,7 +60,7 @@ class OneHotCategorical(distribution.Distribution): dist = OneHotCategorical(probs=p) ``` - Creates a 3-class distiribution, with the 2nd class the most likely to be + Creates a 3-class distribution, with the 2nd class the most likely to be drawn from, using logits. ```python @@ -203,9 +203,6 @@ class OneHotCategorical(distribution.Distribution): ret = array_ops.reshape(ret, logits_shape) return ret - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _entropy(self): return -math_ops.reduce_sum( nn_ops.log_softmax(self.logits) * self.probs, axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index e967dcc90d0712ffc346fb61ee67c44a6d9207cb..02e97c0a2fd004c4fa9382d5367af9f5b034a869 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -35,9 +35,15 @@ __all__ = [ _poisson_sample_note = """ -Note that the input value must be a non-negative floating point tensor with -dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only -legal if it is non-negative and its components are equal to integer values. +The Poisson distribution is technically only defined for non-negative integer +values. When `validate_args=False`, non-integral inputs trigger an assertion. + +When `validate_args=False` calculations are otherwise unchanged despite +integral or non-integral inputs. + +When `validate_args=False`, evaluating the pmf at non-integral values, +corresponds to evaluations of an unnormalized distribution, that does not +correspond to evaluations of the cdf. """ @@ -150,10 +156,6 @@ class Poisson(distribution.Distribution): def _cdf(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - else: - # Whether or not x is integer-form, the following is well-defined. - # However, scipy takes the floor, so we do too. - x = math_ops.floor(x) return math_ops.igammac(1. + x, self.rate) def _log_normalization(self): @@ -162,9 +164,6 @@ class Poisson(distribution.Distribution): def _log_unnormalized_prob(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - else: - # For consistency with cdf, we take the floor. - x = math_ops.floor(x) return x * self.log_rate - math_ops.lgamma(1. + x) def _mean(self): diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 92f2bba1828696248c9d9460566a08ba372c3358..3314181898870fa70dac3dfce42ba84de3d82a4a 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -114,7 +114,7 @@ def quadrature_scheme_lognormal_quantiles( # Create a LogNormal distribution. dist = transformed_lib.TransformedDistribution( distribution=normal_lib.Normal(loc=loc, scale=scale), - bijector=Exp(event_ndims=0), + bijector=Exp(), validate_args=validate_args) batch_ndims = dist.batch_shape.ndims if batch_ndims is None: diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index b525809015537ac8c7ee701c100fba6541fe2e92..e454a53c6275e0c60edd8c87b1c3be670f2b22de 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -35,10 +35,10 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): The RelaxedBernoulli is a distribution over the unit interval (0,1), which continuously approximates a Bernoulli. The degree of approximation is - controlled by a temperature: as the temperaturegoes to 0 the RelaxedBernoulli - becomes discrete with a distribution described by the `logits` or `probs` - parameters, as the temperature goes to infinity the RelaxedBernoulli - becomes the constant distribution that is identically 0.5. + controlled by a temperature: as the temperature goes to 0 the + RelaxedBernoulli becomes discrete with a distribution described by the + `logits` or `probs` parameters, as the temperature goes to infinity the + RelaxedBernoulli becomes the constant distribution that is identically 0.5. The RelaxedBernoulli distribution is a reparameterized continuous distribution that is the binary special case of the RelaxedOneHotCategorical diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 2aa771a71efe52c8d86d459f090ea8ee137c4487..02cf3c7992dc8cde3869ac9f12e7b4372cd6ea2c 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -285,9 +285,6 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): ret = array_ops.reshape(log_prob, logits_shape) return ret - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _assert_valid_sample(self, x): if not self.validate_args: return x @@ -306,7 +303,7 @@ class RelaxedOneHotCategorical( The RelaxedOneHotCategorical is a distribution over random probability vectors, vectors of positive real values that sum to one, which continuously approximates a OneHotCategorical. The degree of approximation is controlled by - a temperature: as the temperaturegoes to 0 the RelaxedOneHotCategorical + a temperature: as the temperature goes to 0 the RelaxedOneHotCategorical becomes discrete with a distribution described by the `logits` or `probs` parameters, as the temperature goes to infinity the RelaxedOneHotCategorical becomes the constant distribution that is identically the constant vector of @@ -412,5 +409,5 @@ class RelaxedOneHotCategorical( validate_args=validate_args, allow_nan_stats=allow_nan_stats) super(RelaxedOneHotCategorical, self).__init__(dist, - bijectors.Exp(event_ndims=1), + bijectors.Exp(), name=name) diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index dfc813361977c159d8d48f9d5b9ff03db5b4acdc..f5aaa5cf34abde3ea4d25de1ecf3adaef3f2a770 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -301,13 +302,16 @@ def percentile(x, with ops.name_scope(name, [x, q]): x = ops.convert_to_tensor(x, name="x") - q = math_ops.to_float(q, name="q") + # Double is needed here and below, else we get the wrong index if the array + # is huge along axis. + q = math_ops.to_double(q, name="q") _get_static_ndims(q, expect_ndims=0) if validate_args: q = control_flow_ops.with_dependencies([ - check_ops.assert_rank(q, 0), check_ops.assert_greater_equal(q, 0.), - check_ops.assert_less_equal(q, 100.) + check_ops.assert_rank(q, 0), + check_ops.assert_greater_equal(q, math_ops.to_double(0.)), + check_ops.assert_less_equal(q, math_ops.to_double(100.)) ], q) if axis is None: @@ -332,7 +336,7 @@ def percentile(x, y = _move_dims_to_flat_end(x, axis, x_ndims) frac_at_q_or_above = 1. - q / 100. - d = math_ops.to_float(array_ops.shape(y)[-1]) + d = math_ops.to_double(array_ops.shape(y)[-1]) if interpolation == "lower": index = math_ops.ceil((d - 1) * frac_at_q_or_above) @@ -341,12 +345,18 @@ def percentile(x, elif interpolation == "nearest": index = math_ops.round((d - 1) * frac_at_q_or_above) + # If d is gigantic, then we would have d == d - 1, even in double... So + # let's use max/min to avoid out of bounds errors. + d = array_ops.shape(y)[-1] + # d - 1 will be distinct from d in int32. + index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1) + # Sort everything, not just the top 'k' entries, which allows multiple calls # to sort only once (under the hood) and use CSE. sorted_y = _sort_tensor(y) # result.shape = B - result = sorted_y[..., math_ops.to_int32(index)] + result = sorted_y[..., index] result.set_shape(y.get_shape()[:-1]) if keep_dims: diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..056d349688511e19a4fa3d58a5b3c1c8355671a3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py @@ -0,0 +1,228 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Local PRNG for amplifying seed entropy into seeds for base operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib + + +class SeedStream(object): + """Local PRNG for amplifying seed entropy into seeds for base operations. + + Writing sampling code which correctly sets the pseudo-random number + generator (PRNG) seed is surprisingly difficult. This class serves as + a helper for the TensorFlow Probability coding pattern designed to + avoid common mistakes. + + # Motivating Example + + A common first-cut implementation of a sampler for the beta + distribution is to compute the ratio of a gamma with itself plus + another gamma. This code snippet tries to do that, but contains a + surprisingly common error: + + ```python + def broken_beta(shape, alpha, beta, seed): + x = tf.random_gamma(shape, alpha, seed=seed) + y = tf.random_gamma(shape, beta, seed=seed) + return x / (x + y) + ``` + + The mistake is that the two gamma draws are seeded with the same + seed. This causes them to always produce the same results, which, + in turn, leads this code snippet to always return `0.5`. Because it + can happen across abstraction boundaries, this kind of error is + surprisingly easy to make when handling immutable seeds. + + # Goals + + TensorFlow Probability adopts a code style designed to eliminate the + above class of error, without exacerbating others. The goals of + this code style are: + + - Support reproducibility of results (by encouraging seeding of all + pseudo-random operations). + + - Avoid shared-write global state (by not relying on a global PRNG). + + - Prevent accidental seed reuse by TF Probability implementers. This + goal is served with the local pseudo-random seed generator provided + in this module. + + - Mitigate potential accidental seed reuse by TF Probability clients + (with a salting scheme). + + - Prevent accidental resonances with downstream PRNGs (by hashing the + output). + + ## Non-goals + + - Implementing a high-performance PRNG for generating large amounts of + entropy. That's the job of the underlying TensorFlow PRNG we are + seeding. + + - Avoiding random seed collisions, aka "birthday attacks". + + # Code pattern + + ```python + def random_beta(shape, alpha, beta, seed): # (a) + seed = SeedStream(seed, salt="random_beta") # (b) + x = tf.random_gamma(shape, alpha, seed=seed()) # (c) + y = tf.random_gamma(shape, beta, seed=seed()) # (c) + return x / (x + y) + ``` + + The elements of this pattern are: + + - Accept an explicit seed (line a) as an argument in all public + functions, and write the function to be deterministic (up to any + numerical issues) for fixed seed. + + - Rationale: This provides the client with the ability to reproduce + results. Accepting an immutable seed rather than a mutable PRNG + object reduces code coupling, permitting different sections to be + reproducible independently. + + - Use that seed only to initialize a local `SeedStream` instance (line b). + + - Rationale: Avoids accidental seed reuse. + + - Supply the name of the function being implemented as a salt to the + `SeedStream` instance (line b). This serves to keep the salts + unique; unique salts ensure that clients of TF Probability will see + different functions always produce independent results even if + called with the same seeds. + + - Seed each callee operation with the output of a unique call to the + `SeedStream` instance (lines c). This ensures reproducibility of + results while preventing seed reuse across callee invocations. + + # Why salt? + + Salting the `SeedStream` instances (with unique salts) is defensive + programming against a client accidentally committing a mistake + similar to our motivating example. Consider the following situation + that might arise without salting: + + ```python + def tfp_foo(seed): + seed = SeedStream(seed, salt="") + foo_stuff = tf.random_normal(seed=seed()) + ... + + def tfp_bar(seed): + seed = SeedStream(seed, salt="") + bar_stuff = tf.random_normal(seed=seed()) + ... + + def client_baz(seed): + foo = tfp_foo(seed=seed) + bar = tfp_bar(seed=seed) + ... + ``` + + The client should have used different seeds as inputs to `foo` and + `bar`. However, because they didn't, *and because `foo` and `bar` + both sample a Gaussian internally as their first action*, the + internal `foo_stuff` and `bar_stuff` will be the same, and the + returned `foo` and `bar` will not be independent, leading to subtly + incorrect answers from the client's simulation. This kind of bug is + particularly insidious for the client, because it depends on a + Distributions implementation detail, namely the order in which `foo` + and `bar` invoke the samplers they depend on. In particular, a + Bayesflow team member can introduce such a bug in previously + (accidentally) correct client code by performing an internal + refactoring that causes this operation order alignment. + + A salting discipline eliminates this problem by making sure that the + seeds seen by `foo`'s callees will differ from those seen by `bar`'s + callees, even if `foo` and `bar` are invoked with the same input + seed. + """ + + def __init__(self, seed, salt): + """Initializes a `SeedStream`. + + Args: + seed: Any Python object convertible to string, supplying the + initial entropy. If `None`, operations seeded with seeds + drawn from this `SeedStream` will follow TensorFlow semantics + for not being seeded. + salt: Any Python object convertible to string, supplying + auxiliary entropy. Must be unique across the Distributions + and TensorFlow Probability code base. See class docstring for + rationale. + """ + self._seed = seed + self._salt = salt + self._counter = 0 + + def __call__(self): + """Returns a fresh integer usable as a seed in downstream operations. + + If this `SeedStream` was initialized with `seed=None`, returns + `None`. This has the effect that downstream operations (both + `SeedStream`s and primitive TensorFlow ops) will behave as though + they were unseeded. + + The returned integer is non-negative, and uniformly distributed in + the half-open interval `[0, 2**512)`. This is consistent with + TensorFlow, as TensorFlow operations internally use the residue of + the given seed modulo `2**31 - 1` (see + `tensorflow/python/framework/random_seed.py`). + + Returns: + seed: A fresh integer usable as a seed in downstream operations, + or `None`. + """ + self._counter += 1 + if self._seed is None: + return None + composite = str((self._seed, self._counter, self._salt)).encode("utf-8") + return int(hashlib.sha512(composite).hexdigest(), 16) + + @property + def original_seed(self): + return self._seed + + @property + def salt(self): + return self._salt + +# Design rationales for the SeedStream class +# +# - Salts are accepted for the reason given above to supply them. +# +# - A `None` seed propagates to downstream seeds, so they exhibit +# their "unseeded" behavior. +# +# - The return value is a Python int so it can be passed directly to +# TensorFlow operations as a seed. It is large to avoid losing seed +# space needlessly (TF will internally read only the last 31 bits). +# +# - The output is hashed with a crypto-grade hash function as a form +# of defensive programming: this reliably prevents all possible +# accidental resonances with all possible downstream PRNGs. The +# specific function used is not important; SHA512 was ready to hand. +# +# - The internal state update is a simple counter because (a) given +# that the output is hashed anyway, this is enough, and (b) letting +# it be this predictable permits a future "generate many seeds in +# parallel" operation whose results would agree with running +# sequentially. diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 5fb6f0c7eaa8c4734ea4c161b0eee6f24d4c9850..bac0b79d5908712f4e64259768fb6f3b4558f620 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -32,45 +32,50 @@ from tensorflow.python.ops.distributions import util as distribution_util class _DistributionShape(object): """Manage and manipulate `Distribution` shape. - Terminology: - Recall that a `Tensor` has: - - `shape`: size of `Tensor` dimensions, - - `ndims`: size of `shape`; number of `Tensor` dimensions, - - `dims`: indexes into `shape`; useful for transpose, reduce. - - `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, - `batch_dims`, and `event_dims`. To understand the semantics of these - dimensions, consider when two of the three are fixed and the remaining - is varied: - - `sample_dims`: indexes independent draws from identical - parameterizations of the `Distribution`. - - `batch_dims`: indexes independent draws from non-identical - parameterizations of the `Distribution`. - - `event_dims`: indexes event coordinates from one sample. - - The `sample`, `batch`, and `event` dimensions constitute the entirety of a - `Distribution` `Tensor`'s shape. - - The dimensions are always in `sample`, `batch`, `event` order. - - Purpose: - This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into - `Distribution` notions of `sample,` `batch,` and `event` dimensions. That - is, it computes any of: + #### Terminology - ``` - sample_shape batch_shape event_shape - sample_dims batch_dims event_dims - sample_ndims batch_ndims event_ndims - ``` + Recall that a `Tensor` has: + - `shape`: size of `Tensor` dimensions, + - `ndims`: size of `shape`; number of `Tensor` dimensions, + - `dims`: indexes into `shape`; useful for transpose, reduce. + + `Tensor`s sampled from a `Distribution` can be partitioned by `sample_dims`, + `batch_dims`, and `event_dims`. To understand the semantics of these + dimensions, consider when two of the three are fixed and the remaining + is varied: + - `sample_dims`: indexes independent draws from identical + parameterizations of the `Distribution`. + - `batch_dims`: indexes independent draws from non-identical + parameterizations of the `Distribution`. + - `event_dims`: indexes event coordinates from one sample. + + The `sample`, `batch`, and `event` dimensions constitute the entirety of a + `Distribution` `Tensor`'s shape. + + The dimensions are always in `sample`, `batch`, `event` order. + + #### Purpose + + This class partitions `Tensor` notions of `shape`, `ndims`, and `dims` into + `Distribution` notions of `sample,` `batch,` and `event` dimensions. That + is, it computes any of: + + ``` + sample_shape batch_shape event_shape + sample_dims batch_dims event_dims + sample_ndims batch_ndims event_ndims + ``` - for a given `Tensor`, e.g., the result of - `Distribution.sample(sample_shape=...)`. + for a given `Tensor`, e.g., the result of + `Distribution.sample(sample_shape=...)`. - For a given `Tensor`, this class computes the above table using minimal - information: `batch_ndims` and `event_ndims`. + For a given `Tensor`, this class computes the above table using minimal + information: `batch_ndims` and `event_ndims`. + + #### Examples + + We show examples of distribution shape semantics. - Examples of `Distribution` `shape` semantics: - Sample dimensions: Computing summary statistics, i.e., the average is a reduction over sample dimensions. @@ -111,52 +116,54 @@ class _DistributionShape(object): tf.div(1., tf.reduce_prod(x, event_dims)) ``` - Examples using this class: - Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. - - ```python - # 150 iid samples from one multivariate Normal with two degrees of freedom. - mu = [0., 0] - sigma = [[1., 0], - [0, 1]] - mvn = MultivariateNormal(mu, sigma) - rand_mvn = mvn.sample(sample_shape=[3, 50]) - shaper = DistributionShape(batch_ndims=0, event_ndims=1) - S, B, E = shaper.get_shape(rand_mvn) - # S = [3, 50] - # B = [] - # E = [2] - - # 12 iid samples from one Wishart with 2x2 events. - sigma = [[1., 0], - [2, 1]] - wishart = Wishart(df=5, scale=sigma) - rand_wishart = wishart.sample(sample_shape=[3, 4]) - shaper = DistributionShape(batch_ndims=0, event_ndims=2) - S, B, E = shaper.get_shape(rand_wishart) - # S = [3, 4] - # B = [] - # E = [2, 2] - - # 100 iid samples from two, non-identical trivariate Normal distributions. - mu = ... # shape(2, 3) - sigma = ... # shape(2, 3, 3) - X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) - # S = [4, 25] - # B = [2] - # E = [3] - ``` - - Argument Validation: - When `validate_args=False`, checks that cannot be done during - graph construction are performed at graph execution. This may result in a - performance degradation because data must be switched from GPU to CPU. - - For example, when `validate_args=False` and `event_ndims` is a - non-constant `Tensor`, it is checked to be a non-negative integer at graph - execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` - arguments are always checked for correctness since this can be done for - "free," i.e., during graph construction. + We show examples using this class. + + Write `S, B, E` for `sample_shape`, `batch_shape`, and `event_shape`. + + ```python + # 150 iid samples from one multivariate Normal with two degrees of freedom. + mu = [0., 0] + sigma = [[1., 0], + [0, 1]] + mvn = MultivariateNormal(mu, sigma) + rand_mvn = mvn.sample(sample_shape=[3, 50]) + shaper = DistributionShape(batch_ndims=0, event_ndims=1) + S, B, E = shaper.get_shape(rand_mvn) + # S = [3, 50] + # B = [] + # E = [2] + + # 12 iid samples from one Wishart with 2x2 events. + sigma = [[1., 0], + [2, 1]] + wishart = Wishart(df=5, scale=sigma) + rand_wishart = wishart.sample(sample_shape=[3, 4]) + shaper = DistributionShape(batch_ndims=0, event_ndims=2) + S, B, E = shaper.get_shape(rand_wishart) + # S = [3, 4] + # B = [] + # E = [2, 2] + + # 100 iid samples from two, non-identical trivariate Normal distributions. + mu = ... # shape(2, 3) + sigma = ... # shape(2, 3, 3) + X = MultivariateNormal(mu, sigma).sample(shape=[4, 25]) + # S = [4, 25] + # B = [2] + # E = [3] + ``` + + #### Argument Validation + + When `validate_args=False`, checks that cannot be done during + graph construction are performed at graph execution. This may result in a + performance degradation because data must be switched from GPU to CPU. + + For example, when `validate_args=False` and `event_ndims` is a + non-constant `Tensor`, it is checked to be a non-negative integer at graph + execution. (Same for `batch_ndims`). Constant `Tensor`s and non-`Tensor` + arguments are always checked for correctness since this can be done for + "free," i.e., during graph construction. """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index c4b8f055b7fbc3f0835b503eddd7617610326d8c..cde6d855009ff45129f603de1462f60b828e661f 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -166,21 +166,20 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): # Make the SAS bijector, 'F'. f = bijectors.SinhArcsinh( - skewness=skewness, tailweight=tailweight, event_ndims=0) + skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = bijectors.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), - tailweight=tailweight, event_ndims=0) + tailweight=tailweight) - # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2)) + # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) - affine = bijectors.Affine( + affine = bijectors.AffineScalar( shift=loc, - scale_identity_multiplier=c, - validate_args=validate_args, - event_ndims=0) + scale=c, + validate_args=validate_args) bijector = bijectors.Chain([affine, f]) diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..9c69435fac109914ff29b307dfad105f62849339 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -0,0 +1,838 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Statistical test assertions calibrated for their error rates. + +Statistical tests have an inescapable probability of error: a correct +sampler can still fail a test by chance, and an incorrect sampler can +still pass a test by chance. This library is about bounding both of +those error rates. This requires admitting a task-specific notion of +"discrepancy": Correct code will fail rarely, code that misbehaves by +more than the discrepancy will pass rarely, and nothing reliable can +be said about code that misbehaves, but misbehaves by less than the +discrepancy. + +# Example + +Consider testing that the mean of a scalar probability distribution P +is some expected constant. Suppose the support of P is the interval +`[0, 1]`. Then you might do this: + +```python +tfd = tf.contrib.distributions + +expected_mean = ... +num_samples = 5000 +samples = ... draw 5000 samples from P + +# Check that the mean looks right +check1 = tfd.assert_true_mean_equal_by_dkwm( + samples, low=0., high=1., expected=expected_mean, + false_fail_rate=1e-6) + +# Check that the difference in means detectable with 5000 samples is +# small enough +check2 = tf.assert_less( + tfd.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=1.0, + false_fail_rate=1e-6, false_pass_rate=1e-6), + 0.01) + +# Be sure to execute both assertion ops +sess.run([check1, check2]) +``` + +The second assertion is an instance of experiment design. It's a +deterministic computation (independent of the code under test) that +checks that `5000` samples is enough to reliably resolve mean +differences of `0.01` or more. Here "reliably" means that if the code +under test is correct, the probability of drawing an unlucky sample +that causes this test to fail is at most 1e-6; and if the code under +test is incorrect enough that its true mean is 0.01 more or less than +expected, then the probability of drawing a "lucky" sample that causes +the test to false-pass is also at most 1e-6. + +# Overview + +Every function in this library can be characterized in terms of: + +- The property being tested, such as the full density of the + distribution under test, or just its true mean, or a single + Bernoulli probability, etc. + +- The relation being asserted, e.g., whether the mean is less, more, + or equal to the given expected value. + +- The stochastic bound being relied upon, such as the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + or the CDF of the binomial distribution (for assertions about + Bernoulli probabilities). + +- The number of sample sets in the statistical test. For example, + testing equality of means has a one-sample variant, where the + expected mean is given exactly, and a two-sample variant, where the + expected mean is itself given by a set of samples (e.g., from an + alternative algorithm). + +- What operation(s) of the test are to be performed. Each test has + three of these: + + 1. `assert` executes the test. Specifically, it creates a TF op that + produces an error if it has enough evidence to prove that the + property under test is violated. These functions depend on the + desired false failure rate, because that determines the sizes of + appropriate confidence intervals, etc. + + 2. `min_discrepancy` computes the smallest difference reliably + detectable by that test, given the sample count and error rates. + What it's a difference of is test-specific. For example, a test + for equality of means would make detection guarantees about the + difference the true means. + + 3. `min_num_samples` computes the minimum number of samples needed + to reliably detect a given discrepancy with given error rates. + + The latter two are for experimental design, and are meant to be + usable either interactively or inline in the overall test method. + +This library follows a naming convention, to make room for every +combination of the above. A name mentions the operation first, then +the property, then the relation, then the bound, then, if the test +takes more than one set of samples, a token indicating this. For +example, `assert_true_mean_equal_by_dkwm` (which is implicitly +one-sample). Each name is a grammatically sound noun phrase (or verb +phrase, for the asserts). + +# Asymptotic properties + +The number of samples needed tends to scale as `O(1/discrepancy**2)` and +as `O(log(1/error_rate))`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops + +__all__ = [ + "true_mean_confidence_interval_by_dkwm", + "assert_true_mean_equal_by_dkwm", + "min_discrepancy_of_true_means_detectable_by_dkwm", + "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_equal_by_dkwm_two_sample", + "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", + "min_num_samples_for_dkwm_mean_two_sample_test", +] + + +def _batch_sort_vector(x, ascending=True, name=None): + with ops.name_scope(name, "_batch_sort_vector", [x]): + x = ops.convert_to_tensor(x, name="x") + n = array_ops.shape(x)[-1] + if ascending: + y, _ = nn_ops.top_k(-x, k=n, sorted=True) + y = -y + else: + y, _ = nn_ops.top_k(x, k=n, sorted=True) + y.set_shape(x.shape) + return y + + +def _do_maximum_mean(samples, envelope, high, name=None): + """Common code between maximum_mean and minimum_mean.""" + with ops.name_scope(name, "do_maximum_mean", [samples, envelope, high]): + n = array_ops.rank(samples) + # Move the batch dimension of `samples` to the rightmost position, + # where the _batch_sort_vector function wants it. + perm = array_ops.concat([math_ops.range(1, n), [0]], axis=0) + samples = array_ops.transpose(samples, perm) + + samples = _batch_sort_vector(samples) + + # The maximum mean is given by taking `envelope`-worth of + # probability from the smallest samples and moving it to the + # maximum value. This amounts to: + # - ignoring the smallest k samples, where `k/n < envelope` + # - taking a `1/n - (envelope - k/n)` part of the index k sample + # - taking all the other samples + # - and adding `envelope * high` at the end. + # The following is a vectorized and batched way of computing this. + # `max_mean_contrib` is a mask implementing the previous. + batch_size = array_ops.shape(samples)[-1] + batch_size = math_ops.cast(batch_size, dtype=samples.dtype.base_dtype) + step = 1. / batch_size + cum_steps = step * math_ops.range( + 1, batch_size + 1, dtype=samples.dtype.base_dtype) + max_mean_contrib = clip_ops.clip_by_value( + cum_steps - envelope[..., array_ops.newaxis], + clip_value_min=0., + clip_value_max=step) + return math_ops.reduce_sum( + samples * max_mean_contrib, axis=-1) + envelope * high + + +def _maximum_mean(samples, envelope, high, name=None): + """Returns a stochastic upper bound on the mean of a scalar distribution. + + The idea is that if the true CDF is within an `eps`-envelope of the + empirical CDF of the samples, and the support is bounded above, then + the mean is bounded above as well. In symbols, + + ```none + sup_x(|F_n(x) - F(x)|) < eps + ``` + + The 0th dimension of `samples` is interpreted as independent and + identically distributed samples. The remaining dimensions are + broadcast together with `envelope` and `high`, and operated on + separately. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `envelope` and `high`. + envelope: Floating-point tensor of sizes of admissible CDF + envelopes (i.e., the `eps` above). + high: Floating-point tensor of upper bounds on the distributions' + supports. + name: A name for this operation (optional). + + Returns: + bound: Floating-point tensor of upper bounds on the true means. + + Raises: + InvalidArgumentError: If some `sample` is found to be larger than + the corresponding `high`. + """ + with ops.name_scope(name, "maximum_mean", [samples, envelope, high]): + samples = ops.convert_to_tensor(samples, name="samples") + envelope = ops.convert_to_tensor(envelope, name="envelope") + high = ops.convert_to_tensor(high, name="high") + + xmax = math_ops.reduce_max(samples, axis=[0]) + msg = "Given sample maximum value exceeds expectations" + check_op = check_ops.assert_less_equal(xmax, high, message=msg) + with ops.control_dependencies([check_op]): + return array_ops.identity(_do_maximum_mean(samples, envelope, high)) + + +def _minimum_mean(samples, envelope, low, name=None): + """Returns a stochastic lower bound on the mean of a scalar distribution. + + The idea is that if the true CDF is within an `eps`-envelope of the + empirical CDF of the samples, and the support is bounded below, then + the mean is bounded below as well. In symbols, + + ```none + sup_x(|F_n(x) - F(x)|) < eps + ``` + + The 0th dimension of `samples` is interpreted as independent and + identically distributed samples. The remaining dimensions are + broadcast together with `envelope` and `low`, and operated on + separately. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `envelope` and `low`. + envelope: Floating-point tensor of sizes of admissible CDF + envelopes (i.e., the `eps` above). + low: Floating-point tensor of lower bounds on the distributions' + supports. + name: A name for this operation (optional). + + Returns: + bound: Floating-point tensor of lower bounds on the true means. + + Raises: + InvalidArgumentError: If some `sample` is found to be smaller than + the corresponding `low`. + """ + with ops.name_scope(name, "minimum_mean", [samples, envelope, low]): + samples = ops.convert_to_tensor(samples, name="samples") + envelope = ops.convert_to_tensor(envelope, name="envelope") + low = ops.convert_to_tensor(low, name="low") + + xmin = math_ops.reduce_min(samples, axis=[0]) + msg = "Given sample minimum value falls below expectations" + check_op = check_ops.assert_greater_equal(xmin, low, message=msg) + with ops.control_dependencies([check_op]): + return - _do_maximum_mean(-samples, envelope, -low) + + +def _dkwm_cdf_envelope(n, error_rate, name=None): + """Computes the CDF envelope that the DKWM inequality licenses. + + The [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + gives a stochastic bound on the distance between the true cumulative + distribution function (CDF) of any distribution and its empirical + CDF. To wit, for `n` iid samples from any distribution with CDF F, + + ```none + P(sup_x |F_n(x) - F(x)| > eps) < 2exp(-2n eps^2) + ``` + + This function computes the envelope size `eps` as a function of the + number of samples `n` and the desired limit on the left-hand + probability above. + + Args: + n: Tensor of numbers of samples drawn. + error_rate: Floating-point tensor of admissible rates of mistakes. + name: A name for this operation (optional). + + Returns: + eps: Tensor of maximum distances the true CDF can be from the + empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and + as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and + `error_rate`. + """ + with ops.name_scope(name, "dkwm_cdf_envelope", [n, error_rate]): + n = math_ops.cast(n, dtype=error_rate.dtype) + return math_ops.sqrt(-gen_math_ops.log(error_rate / 2.) / (2. * n)) + + +def _check_shape_dominates(samples, parameters): + """Check that broadcasting `samples` against `parameters` does not expand it. + + Why? Because I want to be very sure that the samples tensor is not + accidentally enlarged by broadcasting against tensors that are + supposed to be describing the distribution(s) sampled from, lest the + sample counts end up inflated. + + Args: + samples: A Tensor whose shape is to be protected against broadcasting. + parameters: A list of Tensors who are parameters for the statistical test. + + Returns: + samples: Return original `samples` with control dependencies attached + to ensure no broadcasting. + """ + def check(t): + samples_batch_shape = array_ops.shape(samples)[1:] + broadcasted_batch_shape = array_ops.broadcast_dynamic_shape( + samples_batch_shape, array_ops.shape(t)) + # This rank check ensures that I don't get a wrong answer from the + # _shapes_ broadcasting against each other. + samples_batch_ndims = array_ops.size(samples_batch_shape) + ge = check_ops.assert_greater_equal( + samples_batch_ndims, array_ops.rank(t)) + eq = check_ops.assert_equal(samples_batch_shape, broadcasted_batch_shape) + return ge, eq + checks = list(itertools.chain(*[check(t) for t in parameters])) + with ops.control_dependencies(checks): + return array_ops.identity(samples) + + +def true_mean_confidence_interval_by_dkwm( + samples, low, high, error_rate=1e-6, name=None): + """Computes a confidence interval for the mean of a scalar distribution. + + In batch mode, computes confidence intervals for all distributions + in the batch (which need not be identically distributed). + + Relies on the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + The probability (over the randomness of drawing the given samples) + that any true mean is outside the corresponding returned interval is + no more than the given `error_rate`. The size of the intervals + scale as + `O(1 / sqrt(#samples))`, as `O(high - low)`, and as `O(-log(error_rate))`. + + Note that `error_rate` is a total error rate for all the confidence + intervals in the batch. As such, if the batch is nontrivial, the + error rate is not broadcast but divided (evenly) among the batch + members. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + low: Floating-point tensor of lower bounds on the distributions' + supports. + high: Floating-point tensor of upper bounds on the distributions' + supports. + error_rate: *Scalar* admissible total rate of mistakes. + name: A name for this operation (optional). + + Returns: + low: A floating-point tensor of stochastic lower bounds on the true means. + high: A floating-point tensor of stochastic upper bounds on the true means. + """ + with ops.name_scope( + name, "true_mean_confidence_interval_by_dkwm", + [samples, low, high, error_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + error_rate = ops.convert_to_tensor(error_rate, name="error_rate") + samples = _check_shape_dominates(samples, [low, high]) + check_ops.assert_scalar(error_rate) # Static shape + error_rate = _itemwise_error_rate(error_rate, [low, high], samples) + n = array_ops.shape(samples)[0] + envelope = _dkwm_cdf_envelope(n, error_rate) + min_mean = _minimum_mean(samples, envelope, low) + max_mean = _maximum_mean(samples, envelope, high) + return min_mean, max_mean + + +def _itemwise_error_rate( + total_error_rate, param_tensors, sample_tensor=None, name=None): + with ops.name_scope( + name, "itemwise_error_rate", + [total_error_rate, param_tensors, sample_tensor]): + result_shape = [1] + for p_tensor in param_tensors: + result_shape = array_ops.broadcast_dynamic_shape( + array_ops.shape(p_tensor), result_shape) + if sample_tensor is not None: + result_shape = array_ops.broadcast_dynamic_shape( + array_ops.shape(sample_tensor)[1:], result_shape) + num_items = math_ops.reduce_prod(result_shape) + return total_error_rate / math_ops.cast( + num_items, dtype=total_error_rate.dtype) + + +def assert_true_mean_equal_by_dkwm( + samples, low, high, expected, false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is as expected. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the true mean of some distribution from which the given samples are + drawn is _not_ the given expected mean with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want to + check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point tensor of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + low: Floating-point tensor of lower bounds on the distributions' + supports. + high: Floating-point tensor of upper bounds on the distributions' + supports. + expected: Floating-point tensor of expected true means. + false_fail_rate: *Scalar* admissible total rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean is + outside the corresponding confidence interval. + """ + with ops.name_scope( + name, "assert_true_mean_equal_by_dkwm", + [samples, low, high, expected, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected = ops.convert_to_tensor(expected, name="expected") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates(samples, [low, high, expected]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, error_rate=false_fail_rate) + less_op = check_ops.assert_less( + min_mean, expected, message="Mean confidence interval too high") + with ops.control_dependencies([less_op]): + return check_ops.assert_greater( + max_mean, expected, message="Mean confidence interval too low") + + +def min_discrepancy_of_true_means_detectable_by_dkwm( + n, low, high, false_fail_rate, false_pass_rate, name=None): + """Returns the minimum mean discrepancy that a DKWM-based test can detect. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Note that `false_fail_rate` is a total false failure rate for all + the tests in the batch. As such, if the batch is nontrivial, each + member will demand more samples. The `false_pass_rate` is also + interpreted as a total, but is treated asymmetrically: If each test + in the batch detects its corresponding discrepancy with probability + at least `1 - false_pass_rate`, then running all those tests and + failing if any one fails will jointly detect all those discrepancies + with the same `false_pass_rate`. + + Args: + n: Tensor of numbers of samples to be drawn from the distributions + of interest. + low: Floating-point tensor of lower bounds on the distributions' + supports. + high: Floating-point tensor of upper bounds on the distributions' + supports. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + discr: Tensor of lower bounds on the distances between true + means detectable by a DKWM-based test. + + For each batch member `i`, of `K` total, drawing `n[i]` samples from + some scalar distribution supported on `[low[i], high[i]]` is enough + to detect a difference in means of size `discr[i]` or more. + Specifically, we guarantee that (a) if the true mean is the expected + mean, `assert_true_mean_equal_by_dkwm` will fail with probability at + most `false_fail_rate / K` (which amounts to `false_fail_rate` if + applied to the whole batch at once), and (b) if the true mean + differs from the expected mean by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` will pass with probability at most + `false_pass_rate`. + + The detectable discrepancy scales as + + - `O(high[i] - low[i])`, + - `O(1 / sqrt(n[i]))`, + - `O(-log(false_fail_rate/K))`, and + - `O(-log(false_pass_rate))`. + """ + with ops.name_scope( + name, "min_discrepancy_of_true_means_detectable_by_dkwm", + [n, low, high, false_fail_rate, false_pass_rate]): + n = ops.convert_to_tensor(n, name="n") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Algorithm: Assume a true CDF F. The DKWM inequality gives a + # stochastic bound on how far the observed empirical CDF F_n can be. + # Then, using the DKWM inequality again gives a stochastic bound on + # the farthest candidate true CDF F' that + # true_mean_confidence_interval_by_dkwm might consider. At worst, these + # errors may go in the same direction, so the distance between F and + # F' is bounded by the sum. + # On batching: false fail rates sum, so I need to reduce + # the input to account for the batching. False pass rates + # max, so I don't. + sampling_envelope = _dkwm_cdf_envelope(n, false_pass_rate) + false_fail_rate = _itemwise_error_rate(false_fail_rate, [n, low, high]) + analysis_envelope = _dkwm_cdf_envelope(n, false_fail_rate) + return (high - low) * (sampling_envelope + analysis_envelope) + + +def min_num_samples_for_dkwm_mean_test( + discrepancy, low, high, + false_fail_rate=1e-6, false_pass_rate=1e-6, name=None): + """Returns how many samples suffice for a one-sample DKWM mean test. + + To wit, returns an upper bound on the number of samples necessary to + guarantee detecting a mean difference of at least the given + `discrepancy`, with the given `false_fail_rate` and `false_pass_rate`, + using the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval) + on a scalar distribution supported on `[low, high]`. + + Args: + discrepancy: Floating-point tensor of desired upper limits on mean + differences that may go undetected with probability higher than + `1 - false_pass_rate`. + low: Tensor of lower bounds on the distributions' support. + high: Tensor of upper bounds on the distributions' support. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + n: Tensor of numbers of samples to be drawn from the distributions + of interest. + + The `discrepancy`, `low`, and `high` tensors must have + broadcast-compatible shapes. + + For each batch member `i`, of `K` total, drawing `n[i]` samples from + some scalar distribution supported on `[low[i], high[i]]` is enough + to detect a difference in means of size `discrepancy[i]` or more. + Specifically, we guarantee that (a) if the true mean is the expected + mean, `assert_true_mean_equal_by_dkwm` will fail with probability at + most `false_fail_rate / K` (which amounts to `false_fail_rate` if + applied to the whole batch at once), and (b) if the true mean + differs from the expected mean by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` will pass with probability at most + `false_pass_rate`. + + The required number of samples scales + as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, + `O(-log(false_pass_rate))`, and `O(1 / discrepancy[i]**2)`. + """ + with ops.name_scope( + name, "min_num_samples_for_dkwm_mean_test", + [low, high, false_fail_rate, false_pass_rate, discrepancy]): + discrepancy = ops.convert_to_tensor( + discrepancy, name="discrepancy") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Could choose to cleverly allocate envelopes, but this is sound. + envelope1 = discrepancy / (2. * (high - low)) + envelope2 = envelope1 + false_fail_rate = _itemwise_error_rate( + false_fail_rate, [low, high, discrepancy]) + n1 = -math_ops.log(false_fail_rate / 2.) / (2. * envelope1**2) + n2 = -math_ops.log(false_pass_rate / 2.) / (2. * envelope2**2) + return math_ops.maximum(n1, n2) + + +def assert_true_mean_equal_by_dkwm_two_sample( + samples1, low1, high1, samples2, low2, high2, + false_fail_rate=1e-6, name=None): + """Asserts the means of the given distributions are equal. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the means of the distributions from which the given samples are + drawn are _not_ equal with statistical significance `false_fail_rate` + or stronger, otherwise passes. If you also want to check that you + are gathering enough evidence that a pass is not spurious, see + `min_num_samples_for_dkwm_mean_two_sample_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm_two_sample`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples1: Floating-point tensor of samples from the + distribution(s) A. Entries are assumed IID across the 0th + dimension. The other dimensions must broadcast with `low1`, + `high1`, `low2`, and `high2`. + low1: Floating-point tensor of lower bounds on the supports of the + distributions A. + high1: Floating-point tensor of upper bounds on the supports of + the distributions A. + samples2: Floating-point tensor of samples from the + distribution(s) B. Entries are assumed IID across the 0th + dimension. The other dimensions must broadcast with `low1`, + `high1`, `low2`, and `high2`. + low2: Floating-point tensor of lower bounds on the supports of the + distributions B. + high2: Floating-point tensor of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* admissible total rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any pair of confidence + intervals true for corresponding true means do not overlap. + """ + with ops.name_scope( + name, "assert_true_mean_equal_by_dkwm_two_sample", + [samples1, low1, high1, samples2, low2, high2, false_fail_rate]): + samples1 = ops.convert_to_tensor(samples1, name="samples1") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + samples2 = ops.convert_to_tensor(samples2, name="samples2") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples1 = _check_shape_dominates(samples1, [low1, high1]) + samples2 = _check_shape_dominates(samples2, [low2, high2]) + compatible_samples = check_ops.assert_equal( + array_ops.shape(samples1)[1:], array_ops.shape(samples2)[1:]) + with ops.control_dependencies([compatible_samples]): + # Could in principle play games with cleverly allocating + # significance instead of the even split below. It may be possible + # to get tighter intervals, in order to obtain a higher power test. + # Any allocation strategy that depends only on the support bounds + # and sample counts should be valid; however, because the intervals + # scale as O(-log(false_fail_rate)), there doesn't seem to be much + # room to win. + min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm( + samples1, low1, high1, false_fail_rate / 2.) + min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( + samples2, low2, high2, false_fail_rate / 2.) + # I want to assert + # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2), + # but I think I only have and-combination of asserts, so use DeMorgan. + check_confidence_intervals_can_intersect = check_ops.assert_greater_equal( + max_mean_1, min_mean_2, message="Confidence intervals do not " + "intersect: samples1 has a smaller mean than samples2") + with ops.control_dependencies([check_confidence_intervals_can_intersect]): + return check_ops.assert_less_equal( + min_mean_1, max_mean_2, message="Confidence intervals do not " + "intersect: samples2 has a smaller mean than samples1") + + +def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + n1, low1, high1, n2, low2, high2, + false_fail_rate, false_pass_rate, name=None): + """Returns the minimum mean discrepancy for a two-sample DKWM-based test. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Note that `false_fail_rate` is a total false failure rate for all + the tests in the batch. As such, if the batch is nontrivial, each + member will demand more samples. The `false_pass_rate` is also + interpreted as a total, but is treated asymmetrically: If each test + in the batch detects its corresponding discrepancy with probability + at least `1 - false_pass_rate`, then running all those tests and + failing if any one fails will jointly detect all those discrepancies + with the same `false_pass_rate`. + + Args: + n1: Tensor of numbers of samples to be drawn from the distributions A. + low1: Floating-point tensor of lower bounds on the supports of the + distributions A. + high1: Floating-point tensor of upper bounds on the supports of + the distributions A. + n2: Tensor of numbers of samples to be drawn from the distributions B. + low2: Floating-point tensor of lower bounds on the supports of the + distributions B. + high2: Floating-point tensor of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + discr: Tensor of lower bounds on the distances between true means + detectable by a two-sample DKWM-based test. + + For each batch member `i`, of `K` total, drawing `n1[i]` samples + from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` + samples from scalar distribution B supported on `[low2[i], high2[i]]` + is enough to detect a difference in their true means of size + `discr[i]` or more. Specifically, we guarantee that (a) if their + true means are equal, `assert_true_mean_equal_by_dkwm_two_sample` + will fail with probability at most `false_fail_rate/K` (which + amounts to `false_fail_rate` if applied to the whole batch at once), + and (b) if their true means differ by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm_two_sample` will pass with + probability at most `false_pass_rate`. + + The detectable distribution scales as + + - `O(high1[i] - low1[i])`, `O(high2[i] - low2[i])`, + - `O(1 / sqrt(n1[i]))`, `O(1 / sqrt(n2[i]))`, + - `O(-log(false_fail_rate/K))`, and + - `O(-log(false_pass_rate))`. + """ + with ops.name_scope( + name, "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", + [n1, low1, high1, n2, low2, high2, false_fail_rate, false_pass_rate]): + n1 = ops.convert_to_tensor(n1, name="n1") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + n2 = ops.convert_to_tensor(n2, name="n2") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + det_disc1 = min_discrepancy_of_true_means_detectable_by_dkwm( + n1, low1, high1, false_fail_rate / 2., false_pass_rate / 2.) + det_disc2 = min_discrepancy_of_true_means_detectable_by_dkwm( + n2, low2, high2, false_fail_rate / 2., false_pass_rate / 2.) + return det_disc1 + det_disc2 + + +def min_num_samples_for_dkwm_mean_two_sample_test( + discrepancy, low1, high1, low2, high2, + false_fail_rate=1e-6, false_pass_rate=1e-6, name=None): + """Returns how many samples suffice for a two-sample DKWM mean test. + + DKWM is the [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). + + Args: + discrepancy: Floating-point tensor of desired upper limits on mean + differences that may go undetected with probability higher than + `1 - false_pass_rate`. + low1: Floating-point tensor of lower bounds on the supports of the + distributions A. + high1: Floating-point tensor of upper bounds on the supports of + the distributions A. + low2: Floating-point tensor of lower bounds on the supports of the + distributions B. + high2: Floating-point tensor of upper bounds on the supports of + the distributions B. + false_fail_rate: *Scalar* admissible total rate of false failures. + false_pass_rate: *Scalar* admissible rate of false passes. + name: A name for this operation (optional). + + Returns: + n1: Tensor of numbers of samples to be drawn from the distributions A. + n2: Tensor of numbers of samples to be drawn from the distributions B. + + For each batch member `i`, of `K` total, drawing `n1[i]` samples + from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` + samples from scalar distribution B supported on `[low2[i], high2[i]]` + is enough to detect a difference in their true means of size + `discr[i]` or more. Specifically, we guarantee that (a) if their + true means are equal, `assert_true_mean_equal_by_dkwm_two_sample` + will fail with probability at most `false_fail_rate/K` (which + amounts to `false_fail_rate` if applied to the whole batch at once), + and (b) if their true means differ by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm_two_sample` will pass with + probability at most `false_pass_rate`. + + The required number of samples scales as + + - `O((high1[i] - low1[i])**2)`, `O((high2[i] - low2[i])**2)`, + - `O(-log(false_fail_rate/K))`, + - `O(-log(false_pass_rate))`, and + - `O(1 / discrepancy[i]**2)`. + """ + with ops.name_scope( + name, "min_num_samples_for_dkwm_mean_two_sample_test", + [low1, high1, low2, high2, + false_fail_rate, false_pass_rate, discrepancy]): + discrepancy = ops.convert_to_tensor(discrepancy, name="discrepancy") + low1 = ops.convert_to_tensor(low1, name="low1") + high1 = ops.convert_to_tensor(high1, name="high1") + low2 = ops.convert_to_tensor(low2, name="low2") + high2 = ops.convert_to_tensor(high2, name="high2") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + false_pass_rate = ops.convert_to_tensor( + false_pass_rate, name="false_pass_rate") + # Could choose to cleverly allocate discrepancy tolerances and + # failure probabilities, but this is sound. + n1 = min_num_samples_for_dkwm_mean_test( + discrepancy / 2., low1, high1, + false_fail_rate / 2., false_pass_rate / 2.) + n2 = min_num_samples_for_dkwm_mean_test( + discrepancy / 2., low2, high2, + false_fail_rate / 2., false_pass_rate / 2.) + return n1, n2 diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 0c747f8e68529484ae6f695b8500cde74857bb11..da271a852d715cd4bc3423b23e8a597b116027f0 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -181,7 +181,7 @@ def quadrature_scheme_softmaxnormal_quantiles( edges = array_ops.reshape(edges, shape=array_ops.concat([ [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) quantiles = dist.quantile(edges) - quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + quantiles = SoftmaxCentered().forward(quantiles) # Cyclically permute left by one. perm = array_ops.concat([ math_ops.range(1, 1 + batch_ndims), [0]], axis=0) @@ -248,11 +248,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of the quantiles of `p(z)` (generalized quantiles if `K > 2`). - See [1] for more details. - - [1]. "Quadrature Compound: An approximating family of distributions" - Joshua Dillon, Ian Langmore, arXiv preprints - https://arxiv.org/abs/1801.03080 + See [Dillon and Langmore (2018)][1] for more details. #### About `Vector` distributions in TensorFlow. @@ -313,6 +309,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): is_positive_definite=True), ], validate_args=True) + ``` + + #### References + + [1]: Joshua Dillon and Ian Langmore. Quadrature Compound: An approximating + family of distributions. _arXiv preprint arXiv:1801.03080_, 2018. + https://arxiv.org/abs/1801.03080 """ def __init__(self, @@ -424,7 +427,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._endpoint_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, - event_ndims=1, validate_args=validate_args, name="endpoint_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip(loc, scale))] @@ -464,7 +466,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._interpolated_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, - event_ndims=1, validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( @@ -618,9 +619,11 @@ class VectorDiffeomixture(distribution_lib.Distribution): log_prob = math_ops.reduce_sum(self.distribution.log_prob(y), axis=-2) # Because the affine transformation has a constant Jacobian, it is the case # that `affine.fldj(x) = -affine.ildj(x)`. This is not true in general. - fldj = array_ops.stack( - [aff.forward_log_det_jacobian(x) for aff in self.interpolated_affine], - axis=-1) + fldj = array_ops.stack([ + aff.forward_log_det_jacobian( + x, + event_ndims=array_ops.rank(self.event_shape_tensor()) + ) for aff in self.interpolated_affine], axis=-1) return math_ops.reduce_logsumexp( self.mixture_distribution.logits - fldj + log_prob, axis=-1) diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index e1ccf116457a97261b9ce3965552764771d3bdd2..05919be124e8fbfe29e8111a0637db072830ff61 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -215,19 +215,19 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): tailweight = ops.convert_to_tensor( tailweight, dtype=dtype, name="tailweight") f = bijectors.SinhArcsinh( - skewness=skewness, tailweight=tailweight, event_ndims=1) + skewness=skewness, tailweight=tailweight) if has_default_skewness: f_noskew = f else: f_noskew = bijectors.SinhArcsinh( skewness=skewness.dtype.as_numpy_dtype(0.), - tailweight=tailweight, event_ndims=0) + tailweight=tailweight) # Make the Affine bijector, Z --> loc + C * Z. c = 2 * scale_diag_part / f_noskew.forward( ops.convert_to_tensor(2, dtype=dtype)) affine = bijectors.Affine( - shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1) + shift=loc, scale_diag=c, validate_args=validate_args) bijector = bijectors.Chain([affine, f]) diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 8c67647a618d22a58428d78865c4ebf7d98bdf9e..887981d64ef077e2636f8031581c390f177edac8 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -66,7 +66,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): This distribution is an Affine transformation of iid [Student's t-distributions]( https://en.wikipedia.org/wiki/Student%27s_t-distribution) - and should not be confused with the [Multivate Student's t-distribution]( + and should not be confused with the [Multivariate Student's t-distribution]( https://en.wikipedia.org/wiki/Multivariate_t-distribution). The traditional Multivariate Student's t-distribution is type of [elliptical distribution]( diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index e4ac65012b9c7e3ed5ada3ed75020f3905740156..5a8c94dabf4c3c430bee544a48ee7acfe7dd7ed0 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -228,9 +228,12 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbk) # This parametrization is equivalent to Chi2, i.e., # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2) + expanded_df = self.df * array_ops.ones( + self.scale_operator.batch_shape_tensor(), + dtype=self.df.dtype.base_dtype) g = random_ops.random_gamma(shape=[n], alpha=self._multi_gamma_sequence( - 0.5 * self.df, self.dimension), + 0.5 * expanded_df, self.dimension), beta=0.5, dtype=self.dtype, seed=distribution_util.gen_new_seed( diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index 9d2ca07c3a25fa7acb9b0f5806b763d9a57b51fa..9a3b780af888a597d2440b243ffb8dc98d764f18 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,12 +1,8 @@ # Eager Execution -> *WARNING*: This is a preview/pre-alpha version. The API and performance -> characteristics are subject to change. - -Eager execution is an experimental interface to TensorFlow that provides an -imperative programming style (à la [NumPy](http://www.numpy.org)). When you -enable eager execution, TensorFlow operations execute immediately; you do not -execute a pre-constructed graph with +Eager execution provides an imperative interface to TensorFlow (similiar to +[NumPy](http://www.numpy.org)). When you enable eager execution, TensorFlow +operations execute immediately; you do not execute a pre-constructed graph with [`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session). For example, consider a simple computation in TensorFlow: @@ -33,7 +29,7 @@ print(m) ## Caveats This feature is in early stages and work remains to be done in terms of smooth -support for distributed and multi-GPU training and CPU performance. +support for distributed and multi-GPU training and performance. - [Known issues](https://github.com/tensorflow/tensorflow/issues?q=is%3Aissue%20is%3Aopen%20label%3Acomp%3Aeager) - Feedback is welcome, please consider @@ -41,21 +37,23 @@ support for distributed and multi-GPU training and CPU performance. ## Installation -Eager execution is included in TensorFlow versions 1.5 and above. +Eager execution is included in TensorFlow versions 1.7 and above. Installation instructions at https://www.tensorflow.org/install/ ## Documentation For an introduction to eager execution in TensorFlow, see: -- [User Guide](python/g3doc/guide.md) +- [User Guide](https://www.tensorflow.org/programmers_guide/eager) ([source](../../docs_src/programmers_guide/eager.md)) - Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) - Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) - Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) ## Changelog -- 2017/10/31: Initial preview release. +- 2017/10/31: Initial preview release (in TensorFlow 1.5) - 2017/12/01: Example of dynamic neural network: [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). See [README.md](python/examples/spinn/README.md) for details. +- 2017/03: Core functionality moved out of the experimental tf.contrib namespace + in TensorFlow 1.7. diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD deleted file mode 100644 index aedfec8924e7314addd22349c0576a84a58d9aa3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/proto/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -tf_proto_library( - name = "checkpointable_object_graph_proto", - srcs = [ - "checkpointable_object_graph.proto", - ], - visibility = ["//tensorflow/contrib/eager/python:__subpackages__"], -) diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto deleted file mode 100644 index 4f71aec96a2c3edee8a32b4e14584bd56ef3d439..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto +++ /dev/null @@ -1,57 +0,0 @@ -syntax = "proto3"; - -option cc_enable_arenas = true; - -package tensorflow.contrib.eager; - -// Prototype for an addition to BundleHeaderProto which saves extra information -// about the objects which own variables, allowing for more robust checkpoint -// loading into modified programs. - -message CheckpointableObjectGraph { - message Object { - message ObjectReference { - // An index into `CheckpointableObjectGraph.nodes`, indicating the object - // being referenced. - int32 node_id = 1; - // A user-provided name for the edge. - string local_name = 2; - } - - message VariableReference { - // A name for the variable which is unique within the object which owns - // it. Does not include a name_scope or variable_scope prefix. - string local_name = 1; - // The full name of the variable. Used to allow name-based loading of - // checkpoints which were saved using an object-based API. - string full_name = 2; - // The generated name of the variable in the checkpoint. - string checkpoint_key = 3; - } - - message SlotVariableReference { - // An index into `CheckpointableObjectGraph.nodes`, indicating the object - // which created the variable that this variable is slotting for. - int32 original_variable_node_id = 1; - // The local name of the variable being slotted for within the object that - // owns it. - string original_variable_local_name = 2; - // The name of the slot (e.g. "m"/"v"). - string slot_name = 3; - // The full name of the slot variable. Used to allow name-based loading of - // checkpoints which were saved using an object-based API. - string full_name = 4; - // The generated name of the variable in the checkpoint. - string checkpoint_key = 5; - } - - // Objects which this object depends on. - repeated ObjectReference children = 1; - // Non-slot variables owned by this object. - repeated VariableReference variables = 2; - // Slot variables owned by this object. - repeated SlotVariableReference slot_variables = 3; - } - - repeated Object nodes = 1; -} diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index cfb38a1d26c41a3923da7c989244a3d53b6a496b..e2744a430d1efe4b4a688dc7c5caff0bf83de358 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -11,12 +11,14 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":checkpointable_utils", ":datasets", ":metrics", ":network", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", @@ -26,7 +28,6 @@ py_library( "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", - "//tensorflow/python/eager:custom_gradient", "//tensorflow/python/eager:execution_callbacks", "//tensorflow/python/eager:function", ], @@ -69,6 +70,10 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + ":checkpointable_utils", + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/contrib/data/python/ops:threadpool", + "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -77,6 +82,7 @@ cuda_py_test( "//tensorflow/python/data", "//tensorflow/python/eager:test", ], + tags = ["noguitar"], ) py_library( @@ -115,13 +121,14 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/contrib/eager/python:checkpointable_utils", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:summary_ops_v2", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:context", @@ -135,11 +142,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":metrics", - "//tensorflow/contrib/summary:summary_ops", "//tensorflow/contrib/summary:summary_test_util", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:summary_ops_v2", "//tensorflow/python:training", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -156,10 +163,10 @@ py_library( deps = [ ":datasets", ":metrics", - "//tensorflow/contrib/summary:summary_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:summary_ops_v2", "//tensorflow/python/eager:context", "//tensorflow/python/eager:function", "@six_archive//:six", @@ -220,32 +227,23 @@ py_test( ) py_library( - name = "checkpointable", - srcs = ["checkpointable.py"], + name = "checkpointable_utils", + srcs = ["checkpointable_utils.py"], srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:tensor_shape", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", ], ) -py_test( - name = "checkpointable_test", - srcs = ["checkpointable_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":checkpointable", +cuda_py_test( + name = "checkpointable_utils_test", + srcs = ["checkpointable_utils_test.py"], + additional_deps = [ + ":checkpointable_utils", ":network", + "@six_archive//:six", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -260,19 +258,10 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "@six_archive//:six", + "//tensorflow/python/keras", + ], + tags = [ + "no_windows", # TODO: needs investigation on Windows + "notsan", # b/74395663 ], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "g3doc/sitemap.md", - ], - ), - visibility = ["//tensorflow:__subpackages__"], ) diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py deleted file mode 100644 index 896b38a7348e1fdd5a13b197e3ee34f5c4c5a22c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable.py +++ /dev/null @@ -1,773 +0,0 @@ -"""An object-local variable management scheme.""" -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import re -import weakref - -from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training import slot_creator -from tensorflow.python.training import training - -_CheckpointableReference = collections.namedtuple( - "_CheckpointableReference", - [ - # The local name if explicitly specified, else None. - "name", - # The Checkpointable object being referenced. - "ref" - ]) - -# Validation regular expression for the local names of Checkpointable -# objects. In particular, disallows "/" in names, and reserves dash-prefixed -# names (which are not valid Python identifiers, so we're not restricting the -# __setattr__ syntax that way). -_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9_.][A-Za-z0-9_.-]*$") - -# Keyword for identifying that the next bit of a checkpoint variable name is a -# slot name. May not be the local name of a checkpointable. Checkpoint names for -# slot variables look like: -# -# /<_OPTIMIZER_SLOTS_NAME>// -# -# Where is a full path from the checkpoint root to the -# variable being slotted for. -_OPTIMIZER_SLOTS_NAME = "-OPTIMIZER_SLOT" - - -def _assign_existing_variable(variable_to_restore, value_pointer): - """Set a variable from a _ValuePointer object.""" - base_type = variable_to_restore.dtype.base_dtype - with ops.colocate_with(variable_to_restore): - # TODO(allenl): Handle partitioned variables - value_to_restore, = io_ops.restore_v2( - prefix=value_pointer.save_path, - tensor_names=[value_pointer.checkpoint_key], - shape_and_slices=[""], - dtypes=[base_type], - name="checkpoint_initializer") - initializer_op = state_ops.assign(variable_to_restore, value_to_restore) - variable_to_restore._initializer_op = initializer_op # pylint:disable=protected-access - if value_pointer.session is not None: - value_pointer.session.run(initializer_op) - - -def _default_getter(name, shape, dtype, initializer=None, - partition_info=None, **kwargs): - """A pared-down version of get_variable which does not reuse variables.""" - dtype = dtypes.as_dtype(dtype) - shape_object = tensor_shape.as_shape(shape) - with ops.init_scope(): - if initializer is None: - initializer, initializing_from_value = ( - variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access - name=name, shape=shape_object, dtype=dtype)) - else: - initializing_from_value = not callable(initializer) - # Same logic as get_variable - if initializing_from_value: - if shape is not None: - raise ValueError("If initializer is a constant, do not specify shape.") - initial_value = initializer - variable_dtype = None - else: - # Instantiate initializer if provided initializer is a type object. - if isinstance(initializer, type(init_ops.Initializer)): - initializer = initializer(dtype=dtype) - def initial_value(): - return initializer( - shape_object.as_list(), dtype=dtype, partition_info=partition_info) - variable_dtype = dtype.base_dtype - return resource_variable_ops.ResourceVariable( - initial_value=initial_value, - name=name, - dtype=variable_dtype, - **kwargs - ) - - -class Checkpointable(object): - """Manages variables and dependencies on other objects. - - To make reliable checkpoints, all `Checkpointable`s on which this object - depends must be registered in the constructor using `track_checkpointable` in - a deterministic order, and if possible they should be named. Variables may be - created using `add_variable` outside of the constructor and in any order, but - only these variables will be saved. - """ - - def __init__(self): - # A list of _CheckpointableReference objects. - self._checkpoint_dependencies = [] - # Maps names -> Checkpointable objects for named dependencies - self._dependency_names = {} - # Set of all tracked Checkpointables - self._already_tracked = set() - self._owned_variables = {} # local name -> variable object - self._deferred_restorations = {} # local name -> _VariableRestoration - # object - - def __setattr__(self, name, value): - """Support self.foo = checkpointable syntax. - - `self.foo = checkpointable` is equivalent to - `self.foo = self.track_checkpointable(checkpointable, name='foo')`. - - No new tracking if `value` is not a `Checkpointable`, or if `value` is - already being tracked (either because of an explicit `track_checkpointable` - or a previous `__setattr__`). - - Args: - name: The name of the property being set. - value: The new value for the property. - """ - # Give child classes (e.g. Network) priority, then track only if the object - # hasn't been added to _already_tracked. - super(Checkpointable, self).__setattr__(name, value) - if (isinstance(value, Checkpointable) - and value not in self._already_tracked): - self.track_checkpointable(value, name=name) - - def add_variable(self, name, shape=None, dtype=dtypes.float32, - initializer=None, **kwargs): - """Create a new variable object to be saved with this `Checkpointable`. - - If the user has requested that this object or another `Checkpointable` which - depends on this object be restored from a checkpoint (deferred loading - before variable object creation), `initializer` may be ignored and the value - from the checkpoint used instead. - - Args: - name: A name for the variable. Must be unique within this object. - shape: The shape of the variable. - dtype: The data type of the variable. - initializer: The initializer to use. Ignored if deferred loading has been - requested. - **kwargs: Passed to the ResourceVariable constructor. - - Returns: - The new variable object. - - Raises: - ValueError: If the variable name is not unique. - RuntimeError: If __init__ has not been called. - """ - if not hasattr(self, "_owned_variables"): - raise RuntimeError("Need to call Checkpointable.__init__ before adding " - "variables.") - if name in self._owned_variables: - raise ValueError( - ("A variable named '%s' already exists in this Checkpointable, but " - "Checkpointable.add_variable called to create another with " - "that name. Variable names must be unique within a Checkpointable " - "object.") % (name,)) - if "getter" in kwargs: - # Allow the getter to be overridden, typically because there is a need for - # compatibility with some other variable creation mechanism. This should - # be relatively uncommon in user code. - getter = kwargs.pop("getter") - else: - getter = _default_getter - deferred_restoration = self._deferred_restorations.pop(name, None) - if deferred_restoration is not None: - dtype = deferred_restoration.value_pointer.dtype - base_type = dtype.base_dtype - # TODO(allenl): Handle partitioned variables here too - with ops.init_scope(): - initializer, = io_ops.restore_v2( - prefix=deferred_restoration.value_pointer.save_path, - tensor_names=[deferred_restoration.value_pointer.checkpoint_key], - shape_and_slices=[""], - dtypes=[base_type], - name="checkpoint_initializer") - # We need to un-set the shape so get_variable doesn't complain, but we - # also need to set the static shape information on the initializer if - # possible so we don't get a variable with an unknown shape. - initializer.set_shape(shape) - # Un-set shape since we're using a constant initializer - shape = None - - new_variable = getter( - name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) - if deferred_restoration is not None: - if deferred_restoration.value_pointer.session is not None: - deferred_restoration.value_pointer.session.run(new_variable.initializer) - for slot_restoration in deferred_restoration.slot_restorations: - strong_ref = slot_restoration.optimizer_ref() - if strong_ref is None: - # If the optimizer object has been garbage collected, there's no need - # to create the slot variable. - continue - strong_ref._process_slot_restoration( # pylint: disable=protected-access - slot_restoration, new_variable) - self._owned_variables[name] = new_variable - return new_variable - - def track_checkpointable(self, checkpointable, name): - """Declare a dependency on another `Checkpointable` object. - - Indicates that checkpoints for this object should include variables from - `checkpointable`. - - Variables in a checkpoint are mapped to `Checkpointable`s based on names. To - avoid breaking existing checkpoints when modifying a class, neither variable - names nor dependency names (the names passed to `track_checkpointable`) may - change. - - Args: - checkpointable: A `Checkpointable` which this object depends on. - name: A local name for `checkpointable`, used for loading checkpoints into - the correct objects. Python 2 identifiers are valid names, with the - addition of leading numerals, periods anywhere, and non-leading dashes. - Specifically names must match the regular expression - `^[A-Za-z0-9_.][A-Za-z0-9_.-]*$`. - - Returns: - `checkpointable`, for convenience when declaring a dependency and - assigning to a member variable in one statement. - - Raises: - RuntimeError: If __init__ was not called. - TypeError: If `checkpointable` does not inherit from `Checkpointable`. - ValueError: For invalid names. - """ - if not hasattr(self, "_checkpoint_dependencies"): - raise RuntimeError("Need to call Checkpointable.__init__ before calling " - "Checkpointable.track_checkpointable().") - if not isinstance(checkpointable, Checkpointable): - raise TypeError( - ("Checkpointable.track_checkpointable() passed type %s, not a " - "Checkpointable.") % (type(checkpointable),)) - if not _VALID_LOCAL_NAME.match(name): - raise ValueError( - ("Checkpointable names must match the regular expression '%s', but " - "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, - name)) - if (name in self._dependency_names - and self._dependency_names[name] is not checkpointable): - raise ValueError( - ("Called Checkpointable.track_checkpointable() with name='%s', but " - "a Checkpointable with this name is already declared as a " - "dependency. Names must be unique.") % (name,)) - self._dependency_names[name] = checkpointable - self._checkpoint_dependencies.append( - _CheckpointableReference(name=name, ref=checkpointable)) - self._already_tracked.add(checkpointable) - return checkpointable - - def _process_restoration(self, restoration): - """Restore a variable and its slot variables (may be deferred).""" - variable_to_restore = self._owned_variables.get(restoration.name, None) - if variable_to_restore is not None: - # This variable already exists, so just do an assignment for this and any - # slot variables which depend on it. - _assign_existing_variable( - variable_to_restore, value_pointer=restoration.value_pointer) - for slot_restoration in restoration.slot_restorations: - strong_ref = slot_restoration.optimizer_ref() - if strong_ref is None: - continue - strong_ref._process_slot_restoration( # pylint: disable=protected-access - slot_restoration, variable_to_restore) - else: - # Save this restoration for later. This intentionally overwrites any - # previous deferred restorations, since that gives the same semantics as - # direct assignment. - self._deferred_restorations[restoration.name] = restoration - - def _process_slot_restoration(self, slot_restoration, variable): - """Restore a slot variable's value (creating it if necessary).""" - # TODO(allenl): Move this to Optimizer - assert isinstance(self, optimizer_lib.Optimizer) - named_slots = self._slot_dict(slot_restoration.slot_name) - variable_key = optimizer_lib._var_key(variable) # pylint: disable=protected-access - existing_slot_variable = named_slots.get(variable_key, None) - if existing_slot_variable is None: - base_dtype = slot_restoration.value_pointer.dtype.base_dtype - initializer, = io_ops.restore_v2( - prefix=slot_restoration.value_pointer.save_path, - tensor_names=[slot_restoration.value_pointer.checkpoint_key], - shape_and_slices=[""], - dtypes=[base_dtype], - name="checkpoint_initializer") - new_slot_variable = slot_creator.create_slot(variable, initializer, - slot_restoration.slot_name) - if slot_restoration.value_pointer.session is not None: - slot_restoration.value_pointer.session.run( - new_slot_variable.initializer) - named_slots[variable_key] = new_slot_variable - else: - _assign_existing_variable( - existing_slot_variable, value_pointer=slot_restoration.value_pointer) - - @property - def checkpoint_dependencies(self): - """Other `Checkpointable` objects on which this object depends.""" - return self._checkpoint_dependencies - - -def _breadth_first_checkpointable_traversal(root_checkpointable): - """Find shortest paths to all variables owned by dependencies of root.""" - bfs_sorted = [] - root_checkpointable_reference = _CheckpointableReference( - name=None, ref=root_checkpointable) - to_visit = collections.deque([root_checkpointable_reference]) - path_to_root = {root_checkpointable_reference: ()} - while to_visit: - current_checkpointable = to_visit.popleft() - bfs_sorted.append(current_checkpointable) - for child_checkpointable in ( - current_checkpointable.ref.checkpoint_dependencies): - if child_checkpointable not in path_to_root: - path_to_root[child_checkpointable] = ( - path_to_root[current_checkpointable] + (child_checkpointable,)) - to_visit.append(child_checkpointable) - return bfs_sorted, path_to_root - - -def _object_prefix_from_path(path_to_root): - return "/".join( - (checkpointable.name for checkpointable in path_to_root)) - - -def _escape_variable_name(variable_name): - # We need to support slashes in variable names for compatibility, since this - # naming scheme is being patched in to things like Layer.add_variable where - # slashes were previously accepted. We also want to use slashes to indicate - # edges traversed to reach the variable, so we escape forward slashes in - # variable names. - return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") - - -def _variable_naming_for_object(path_to_root): - """Make a function for naming variables in an object.""" - # Name non-slot variables: - # - # / - # - # is not necessarily unique, but this is fine since we also - # save the graph of `Checkpointable`s with the checkpoint. Even if this path - # no longer exists because of a change in the Python program, we can look up - # the `Checkpointable` which owns the variable in the checkpoint's graph and - # use another path if one still exists. - - object_prefix = _object_prefix_from_path(path_to_root) - if object_prefix: - object_prefix += "/" - - def _name_single_variable(local_name): - """Names a variable within an object.""" - return object_prefix + _escape_variable_name(local_name) - - return _name_single_variable - - -def _slot_variable_naming_for_optimizer(optimizer, path_to_root): - """Make a function for naming slot variables in an optimizer.""" - # Name slot variables: - # - # /<_OPTIMIZER_SLOTS_NAME>// - # - # where is exactly the checkpoint name used for the original - # variable, including the path from the checkpoint root and the local name in - # the object which owns it. Note that we only save slot variables if the - # variable it's slotting for is also being saved. - - optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, - _object_prefix_from_path(path_to_root)) - - def _name_slot_variable(variable_path, slot_name): - """With an optimizer specified, name a slot variable.""" - - if not _VALID_LOCAL_NAME.match(slot_name): - # Slot variable names include the name of the slot. We need to - # validate that part of the name to be sure that the checkpoint name - # is a valid name scope name. - raise ValueError( - ("Could not save slot variables for optimizer %s, because its " - "slot name has invalid characters (got '%s', was expecting it " - "to match the regular expression '%s').") % - (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) - - return variable_path + optimizer_identifier + slot_name - - return _name_slot_variable - - -def _serialize_non_slot_variables(checkpointable_objects, path_to_root, - object_graph_proto): - """Name non-slot variables and add them to `object_graph_proto`.""" - named_variables = {} - non_slot_variables = [] - checkpoint_node_ids = {} - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - checkpoint_node_ids[checkpointable] = checkpoint_id - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) - object_proto = object_graph_proto.nodes.add() - for (local_name, owned_variable) in sorted( - checkpointable.ref._owned_variables.items(), # pylint: disable=protected-access - key=lambda x: x[0]): - variable_name = naming_scheme(local_name) - named_variables[variable_name] = owned_variable - non_slot_variables.append(( - variable_name, # The variable's full checkpoint name - owned_variable, # The variable object - local_name, # The variable's local name - checkpoint_id)) # The checkpoint ID of the node which owns this - # variable. - variable_proto = object_proto.variables.add() - variable_proto.local_name = local_name - variable_proto.checkpoint_key = variable_name - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [owned_variable], convert_variable_to_tensor=False) - variable_full_name, = saver_dict.keys() - variable_proto.full_name = variable_full_name - - for child in checkpointable.ref.checkpoint_dependencies: - child_proto = object_proto.children.add() - child_proto.node_id = checkpoint_node_ids[child] - child_proto.local_name = child.name - return named_variables, non_slot_variables - - -def _serialize_slot_variables(checkpointable_objects, path_to_root, - non_slot_variables, object_graph_proto): - """Name slot variables and add them to `object_graph_proto`.""" - named_slot_variables = {} - for optimizer_checkpoint_id, checkpointable_ref in enumerate( - checkpointable_objects): - if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): - optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] - naming_scheme = _slot_variable_naming_for_optimizer( - optimizer=checkpointable_ref.ref, - path_to_root=path_to_root[checkpointable_ref]) - slot_names = checkpointable_ref.ref.get_slot_names() - for (variable_path, original_variable, original_variable_local_name, - original_node_checkpoint_id) in non_slot_variables: - for slot_name in slot_names: - slot_variable = checkpointable_ref.ref.get_slot( - original_variable, slot_name) - if slot_variable is not None: - checkpoint_name = naming_scheme( - variable_path=variable_path, slot_name=slot_name) - named_slot_variables[checkpoint_name] = slot_variable - slot_variable_proto = optimizer_object_proto.slot_variables.add() - slot_variable_proto.slot_name = slot_name - slot_variable_proto.checkpoint_key = checkpoint_name - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [slot_variable], convert_variable_to_tensor=False) - slot_variable_full_name, = saver_dict.keys() - slot_variable_proto.full_name = slot_variable_full_name - slot_variable_proto.original_variable_local_name = ( - original_variable_local_name) - slot_variable_proto.original_variable_node_id = ( - original_node_checkpoint_id) - return named_slot_variables - - -# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct -# a root Checkpointable if passed a list of Checkpointables). -def _serialize_object_graph(root_checkpointable): - """Determine checkpoint keys for variables and build a serialized graph. - - Non-slot variables are keyed based on a shortest path from the root saveable - to the object which owns the variable (i.e. the one which called - `Checkpointable.add_variable` to create it). - - Slot variables are keyed based on a shortest path to the variable being - slotted for, a shortest path to their optimizer, and the slot name. - - Args: - root_checkpointable: A `Checkpointable` object whose variables (including - the variables of dependencies, recursively) should be saved. - - Returns: - A tuple of (named_variables, object_graph_proto): - named_variables: A dictionary mapping names to variable objects. - object_graph_proto: A CheckpointableObjectGraph protocol buffer containing - the serialized object graph and variable references. - - Raises: - ValueError: If there are invalid characters in an optimizer's slot names. - """ - checkpointable_objects, path_to_root = ( - _breadth_first_checkpointable_traversal(root_checkpointable)) - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - - # Gather non-slot variables. - named_variables, non_slot_variables = _serialize_non_slot_variables( - checkpointable_objects, path_to_root, object_graph_proto) - - # Gather slot variables which are associated with variables gathered above. - named_slot_variables = _serialize_slot_variables( - checkpointable_objects, path_to_root, non_slot_variables, - object_graph_proto) - - named_variables.update(named_slot_variables) - return named_variables, object_graph_proto - - -def _set_reference(reference_proto_table, key, checkpointable, parent, - object_id_map): - """Record a checkpoint<->object correspondence, with error checking. - - Args: - reference_proto_table: Map from names or numbers to `ObjectReference` protos - within the parent object. - key: Either a numeric or string identifier for the reference. - checkpointable: The object to record a correspondence for. - parent: The parent Python object, for creating a useful error message. - object_id_map: The map from `node_id` to Python object in which to record - the reference. - Returns: - The `node_id` of the Object proto corresponding to the specified Python - object. - Raises: - AssertionError: If another object is already bound to the `Object` proto. - """ - reference_proto = reference_proto_table[key] - set_reference = object_id_map.setdefault(reference_proto.node_id, - checkpointable) - if set_reference is not checkpointable: - raise AssertionError( - ("Unable to load the checkpoint into this object graph. Either " - "the Checkpointable object references in the Python program " - "have changed in an incompatible way, or the checkpoint was " - "generated in an incompatible program.\n\nTwo checkpoint " - "references (one being '%s' in %s) resolved to different " - "objects (%s and %s).") % (key, parent, set_reference, - checkpointable)) - return reference_proto.node_id - - -def _checkpoint_object_id_map(root_checkpointable, object_graph_proto): - """Match a checkpointed object graph to a Python object graph. - - Args: - root_checkpointable: A Checkpointable object. - object_graph_proto: A CheckpointableObjectGraph protocol buffer representing - a serialized object graph. - Returns: - A dictionary mapping from checkpoint node ids (indices into - `object_graph_proto.nodes`) to `Checkpointable` objects which are - dependencies of `root_checkpointable`. - """ - node_list = object_graph_proto.nodes - # Queue of (checkpointable object, node id) - to_visit = collections.deque([(root_checkpointable, 0)]) - object_id_map = {0: root_checkpointable} - seen = set() - while to_visit: - checkpointable, node_id = to_visit.popleft() - object_proto = node_list[node_id] - named_children = {} - for child_reference in object_proto.children: - if child_reference.local_name: - named_children[child_reference.local_name] = child_reference - else: - raise AssertionError( - ("The checkpointed object graph contains a reference without " - "a name (corrupted?). The reference was from the node %s.") - % (object_proto,)) - - for checkpointable_reference in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access - child_node_id = _set_reference( - reference_proto_table=named_children, - key=checkpointable_reference.name, - checkpointable=checkpointable_reference.ref, - parent=checkpointable, - object_id_map=object_id_map) - if child_node_id not in seen: - seen.add(child_node_id) - to_visit.append((checkpointable_reference.ref, child_node_id)) - - return object_id_map - - -_ValuePointer = collections.namedtuple( - "_ValuePointer", - [ - # Information needed to look up the value to restore. - "save_path", - "checkpoint_key", - "dtype", - # The session to use when restoring (None when executing eagerly) - "session", - ]) - -_SlotVariableRestoration = collections.namedtuple( - "_SlotVariableRestoration", - [ - # A weak reference to the Optimizer object - "optimizer_ref", - # The slot name - "slot_name", - # The _ValuePointer to use when restoring - "value_pointer", - ]) - -_VariableRestoration = collections.namedtuple( - "_VariableRestoration", - [ - # The variable's (local) name. - "name", - # _SlotVariableRestoration objects indicating slot variables which - # should be created once this variable has been restored. - "slot_restorations", - # The _ValuePointer to use when restoring - "value_pointer", - ]) - - -def _gather_restorations(object_graph_proto, save_path, object_id_map, - dtype_map, session): - """Iterate over variables to restore, matching with Checkpointable objects.""" - variable_to_slot_restorations = {} - for node_id, node in enumerate(object_graph_proto.nodes): - for slot_variable in node.slot_variables: - original_variable_key = (slot_variable.original_variable_node_id, - slot_variable.original_variable_local_name) - variable_to_slot_restorations.setdefault( - original_variable_key, []).append( - _SlotVariableRestoration( - optimizer_ref=weakref.ref(object_id_map[node_id]), - slot_name=slot_variable.slot_name, - value_pointer=_ValuePointer( - save_path=save_path, - checkpoint_key=slot_variable.checkpoint_key, - dtype=dtype_map[slot_variable.checkpoint_key], - session=session))) - - for node_id, node in enumerate(object_graph_proto.nodes): - for variable in node.variables: - slots_key = (node_id, variable.local_name) - variable_restore = _VariableRestoration( - name=variable.local_name, - slot_restorations=variable_to_slot_restorations.get(slots_key, []), - value_pointer=_ValuePointer( - save_path=save_path, - checkpoint_key=variable.checkpoint_key, - dtype=dtype_map[variable.checkpoint_key], - session=session)) - yield variable_restore, object_id_map[node_id] - - -def save(file_prefix, root_checkpointable, global_step=None, session=None): - """Save a training checkpoint. - - Args: - file_prefix: A prefix to use for the checkpoint filenames - (/path/to/directory/and_a_prefix). Names are generated based on this - prefix and the global step, if provided. - root_checkpointable: A Checkpointable object to save. The checkpoint - includes variables created by this object and any Checkpointable objects - it depends on. - global_step: An integer variable or Tensor, used to number - checkpoints. Typically this value is saved along with other variables in - training checkpoints, which will happen automatically if it was created by - `root_checkpointable` or one of its dependencies (via - `Checkpointable.add_variable`). - session: The session to evaluate variables in. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - - Returns: - The full path to the checkpoint. - - Currently also returns the serialized object graph proto, but that will go - away once it's saved with the checkpoint. - """ - named_variables, serialized_graph = _serialize_object_graph( - root_checkpointable) - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - with ops.device("/device:CPU:0"): - save_path = saver_lib.Saver(var_list=named_variables).save( - sess=session, - save_path=file_prefix, - write_meta_graph=False, - global_step=global_step) - # TODO(allenl): Save the graph with the checkpoint, then returning it and - # taking it as an argument to restore won't be necessary. - return serialized_graph, save_path - - -# NOTE: Will be restore(file_prefix, root_checkpointable) once the object graph -# is saved with the checkpoint. -def restore(save_path, root_checkpointable, object_graph_proto, session=None): - """Restore a training checkpoint. - - Restores the values of variables created with `Checkpointable.add_variable` in - the dependency graph of `root_checkpointable`. Either assigns values - immediately (if variables to restore have been created already), or defers - restoration until the variables are created. - - When building a graph, restorations are executed in the default session if - `session` is `None`. Variable initializers read checkpointed values. - - Args: - save_path: The path to the checkpoint, as returned by `save` or - `tf.train.latest_checkpoint`. If None (as when there is no latest - checkpoint for `tf.train.latest_checkpoint` to return), does nothing. - root_checkpointable: The root of the object graph to restore. Variables to - restore need not have been created yet, but all dependencies on other - Checkpointable objects should already be declared. Objects in the - dependency graph are matched to objects in the checkpointed graph, and - matching objects have their variables restored (or the checkpointed values - saved for eventual restoration when the variable is created). - object_graph_proto: (Temporary) the checkpointed object graph. This will - eventually be saved with the checkpoint, and will not be part of the final - API. - session: The session to evaluate assignment ops in. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - """ - if save_path is None: - return - object_id_map = _checkpoint_object_id_map(root_checkpointable, - object_graph_proto) - reader = training.NewCheckpointReader(save_path) - dtype_map = reader.get_variable_to_dtype_map() - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - for restoration, checkpointable in _gather_restorations( - object_graph_proto, save_path, object_id_map, dtype_map, session=session): - checkpointable._process_restoration(restoration) # pylint: disable=protected-access - diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py deleted file mode 100644 index f7bc155decbb574ddd4b53190da3c3b3ee9b6a4e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable_test.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import os - -import six - -from tensorflow.contrib.eager.python import checkpointable -from tensorflow.contrib.eager.python import network as network_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.layers import base -from tensorflow.python.layers import core -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.training import adam -from tensorflow.python.training import saver as core_saver -from tensorflow.python.training import training_util - - -class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - core.Dense.__init__(self, *args, **kwargs) - - def add_variable(self, name, shape, **kwargs): - # Calls both Checkpointable.add_variable and Layer.add_variable. Eventually - # Layer.add_variable should inherit from Checkpointable and simply call - # super and then do post-processing. - return checkpointable.Checkpointable.add_variable( - self, - name=name, - shape=shape, - getter=functools.partial(core.Dense.add_variable, self), - **kwargs) - - -# pylint: disable=not-callable -class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): - - def __init__(self): - network_lib.Network.__init__(self) - checkpointable.Checkpointable.__init__(self) - - def __setattr__(self, name, value): - if isinstance(value, base.Layer) and value not in self._already_tracked: - self.track_layer(value, name=name) - # Checkpointable is next in the method resolution order, so this will catch - # Checkpointable objects which aren't Layers. - super(CheckpointableNetwork, self).__setattr__(name, value) - - def track_layer(self, layer, name): - self.track_checkpointable(layer, name=name) - return super(CheckpointableNetwork, self).track_layer(layer) - - -class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - adam.AdamOptimizer.__init__(self, *args, **kwargs) - - # NOTE: Copied from Optimizer with modifications to use add_variable - # for non-slot variables. These contortions are necessary to maintain - # checkpoint compatibility with variable.name based saving. - # TODO(allenl): Make this cleaner. - def _create_non_slot_variable(self, initial_value, name, colocate_with): - """Add an extra variable, not associated with a slot.""" - if context.in_graph_mode(): - graph = colocate_with.graph - else: - graph = None - - key = (name, graph) - v = self._non_slot_dict.get(key, None) - if v is None: - with ops.colocate_with(colocate_with): - def _variable_getter(name, shape, dtype, initializer): - del shape, dtype # not used, but there for compatibility - return variable_scope.variable( - name=name, initial_value=initializer, trainable=False) - - initial_value = ops.convert_to_tensor(initial_value) - v = self.add_variable( - name=name, - shape=initial_value.get_shape(), - initializer=initial_value, - getter=_variable_getter) - - self._non_slot_dict[key] = v - - return v - - -class NonLayerCheckpointable(checkpointable.Checkpointable): - - def __init__(self): - super(NonLayerCheckpointable, self).__init__() - self.a_variable = self.add_variable(name="a_variable", shape=[]) - - -class MyNetwork(CheckpointableNetwork): - """A concrete Network for testing.""" - - def __init__(self): - super(MyNetwork, self).__init__() - self._named_dense = CheckpointableDenseLayer(1, use_bias=True) - self._via_track_layer = self.track_layer( - CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() - - def call(self, values): - return self._via_track_layer(self._named_dense(values)) - - -class Root(checkpointable.Checkpointable): - """A stand-in for a Trainer class.""" - - def __init__(self, optimizer, network): - super(Root, self).__init__() - self._optimizer = optimizer - self._network = self.track_checkpointable(network, "network") - self._global_step = None - - @property - def global_step(self): - if self._global_step is None: - # Get the default create_global_step utility to actually call - # self.add_variable, by setting a custom creator. - def _owned_variable_as_creator( - next_creator, initial_value, **kwargs): - def _creator_as_getter(initializer, **kwargs): - return next_creator(initial_value=initializer, **kwargs) - return self.add_variable( - getter=_creator_as_getter, initializer=initial_value, shape=[], - **kwargs) - - with variable_scope.variable_creator_scope( - _owned_variable_as_creator): - self._global_step = training_util.create_global_step() - return self._global_step - - -class InterfaceTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testAddVariable(self): - obj = NonLayerCheckpointable() - with self.assertRaisesRegexp(ValueError, "do not specify shape"): - obj.add_variable( - name="shape_specified_twice", shape=[], initializer=1) - constant_initializer = obj.add_variable( - name="constant_initializer", initializer=1) - with variable_scope.variable_scope("some_variable_scope"): - ones_initializer = obj.add_variable( - name="ones_initializer", - shape=[2], - initializer=init_ops.ones_initializer(dtype=dtypes.float32)) - bare_initializer = obj.add_variable( - name="bare_initializer", - shape=[2, 2], - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) - - # Even in graph mode, there are no naming conflicts between objects, only - # naming conflicts within an object. - other_duplicate = resource_variable_ops.ResourceVariable( - name="duplicate", initial_value=1.) - duplicate = obj.add_variable(name="duplicate", shape=[]) - with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): - obj.add_variable(name="duplicate", shape=[]) - - if context.in_graph_mode(): - self.evaluate(variables.global_variables_initializer()) - self.assertEqual("constant_initializer:0", constant_initializer.name) - self.assertEqual(1, self.evaluate(constant_initializer)) - self.assertEqual("some_variable_scope/ones_initializer:0", - ones_initializer.name) - self.assertAllEqual([1, 1], self.evaluate(ones_initializer)) - self.assertAllEqual([[0., 0.], - [0., 0.]], self.evaluate(bare_initializer)) - self.assertEqual("a_variable:0", obj.a_variable.name) - self.assertEqual("duplicate:0", other_duplicate.name) - if context.in_graph_mode(): - # The .name attribute may be globally influenced, but the checkpoint name - # won't be (tested below). - self.assertEqual("duplicate_1:0", duplicate.name) - else: - # When executing eagerly, there's no uniquification of variable names. The - # checkpoint name will be the same. - self.assertEqual("duplicate:0", duplicate.name) - named_variables, _ = checkpointable._serialize_object_graph(obj) - expected_checkpoint_names = ( - "a_variable", - "bare_initializer", - "constant_initializer", - "duplicate", - "ones_initializer", - ) - six.assertCountEqual( - self, expected_checkpoint_names, named_variables.keys()) - - def testInitNotCalled(self): - - class NoInit(checkpointable.Checkpointable): - - def __init__(self): - pass - - with self.assertRaisesRegexp(RuntimeError, "__init__"): - NoInit().add_variable("var", shape=[]) - - -class CheckpointingTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNamingWithOptimizer(self): - input_value = constant_op.constant([[3.]]) - network = MyNetwork() - # A nuisance Network using the same optimizer. Its slot variables should not - # go in the checkpoint, since it is never depended on. - other_network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Root(optimizer=optimizer, network=network) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value), - global_step=root_checkpointable.global_step) - optimizer.minimize( - lambda: other_network(input_value), - global_step=root_checkpointable.global_step) - else: - train_op = optimizer.minimize( - network(input_value), global_step=root_checkpointable.global_step) - optimizer.minimize( - other_network(input_value), - global_step=root_checkpointable.global_step) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - named_variables, serialized_graph = checkpointable._serialize_object_graph( - root_checkpointable) - expected_checkpoint_names = ( - # Created in the root node, so no prefix. - "global_step", - # No name provided to track_checkpointable(), so the position is used - # instead (one-based). - "network/via_track_layer/kernel", - # track_checkpointable() with a name provided, so that's used - "network/_named_dense/kernel", - "network/_named_dense/bias", - # non-Layer dependency of the network - "network/_non_layer/a_variable", - # The optimizer creates two non-slot variables - "_optimizer/beta1_power", - "_optimizer/beta2_power", - # Slot variables - "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/m", - "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/v", - "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m", - "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v", - "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", - "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v", - ) - six.assertCountEqual(self, expected_checkpoint_names, - named_variables.keys()) - # Check that we've mapped to the right variable objects (not exhaustive) - self.assertEqual("global_step:0", named_variables["global_step"].name) - self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", - named_variables["network/via_track_layer/kernel"].name) - self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", - named_variables["network/_named_dense/kernel"].name) - self.assertEqual("beta1_power:0", - named_variables["_optimizer/beta1_power"].name) - self.assertEqual("beta2_power:0", - named_variables["_optimizer/beta2_power"].name) - # Spot check the generated protocol buffers. - self.assertEqual("_optimizer", - serialized_graph.nodes[0].children[0].local_name) - optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ - 0].node_id] - self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) - self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) - # Variable ordering is arbitrary but deterministic (alphabetized) - self.assertEqual( - "bias", optimizer_node.slot_variables[0].original_variable_local_name) - original_variable_owner = serialized_graph.nodes[ - optimizer_node.slot_variables[0].original_variable_node_id] - self.assertEqual("network/_named_dense/bias", - original_variable_owner.variables[0].checkpoint_key) - self.assertEqual("bias", original_variable_owner.variables[0].local_name) - self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) - self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", - optimizer_node.slot_variables[0].checkpoint_key) - # We strip off the :0 suffix, as variable.name-based saving does. - self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam", - optimizer_node.slot_variables[0].full_name) - self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0", - optimizer.get_slot( - var=named_variables["network/_named_dense/bias"], - name="m").name) - - @test_util.run_in_graph_and_eager_modes() - def testSaveRestore(self): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Root(optimizer=optimizer, network=network) - input_value = constant_op.constant([[3.]]) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value), - global_step=root_checkpointable.global_step) - else: - train_op = optimizer.minimize( - network(input_value), global_step=root_checkpointable.global_step) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) - m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") - self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - serialized_graph, save_path = checkpointable.save( - file_prefix=prefix, - root_checkpointable=root_checkpointable, - global_step=root_checkpointable.global_step) - self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.global_step, 3)) - optimizer_variables = self.evaluate(optimizer.variables()) - self.evaluate(state_ops.assign(m_bias_slot, [-2.])) - # Immediate restoration - checkpointable.restore( - save_path=save_path, - root_checkpointable=root_checkpointable, - object_graph_proto=serialized_graph) - self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step)) - self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) - with ops.Graph().as_default(): - on_create_network = MyNetwork() - on_create_optimizer = CheckpointableAdam(0.001) - on_create_root = Root( - optimizer=on_create_optimizer, network=on_create_network) - with self.test_session(graph=ops.get_default_graph()): - # Deferred restoration - checkpointable.restore( - save_path=save_path, - root_checkpointable=on_create_root, - object_graph_proto=serialized_graph) - on_create_network(constant_op.constant([[3.]])) # create variables - self.assertAllEqual(1, self.evaluate(on_create_root.global_step)) - self.assertAllEqual([42.], - self.evaluate( - on_create_network._named_dense.variables[1])) - on_create_m_bias_slot = on_create_optimizer.get_slot( - on_create_network._named_dense.variables[1], "m") - # Optimizer slot variables are created when the original variable is - # restored. - self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) - # beta1_power and beta2_power haven't been created yet, but everything - # else matches. - self.assertAllEqual(optimizer_variables[2:], - self.evaluate(on_create_optimizer.variables())) - on_create_optimizer._create_slots( - [resource_variable_ops.ResourceVariable([1.])]) - beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() - self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) - self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) - - def testDeferredRestorationUsageEager(self): - """An idiomatic eager execution example.""" - num_training_steps = 10 - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - latest_object_graph = None # Will be saved with the checkpoint eventually. - for training_continuation in range(3): - with ops.Graph().as_default(): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Root(optimizer=optimizer, network=network) - checkpointable.restore( - save_path=core_saver.latest_checkpoint(checkpoint_directory), - root_checkpointable=root, - object_graph_proto=latest_object_graph) - for _ in range(num_training_steps): - # TODO(allenl): Use a Dataset and serialize/checkpoint it. - input_value = constant_op.constant([[3.]]) - optimizer.minimize( - lambda: network(input_value), # pylint: disable=cell-var-from-loop - global_step=root.global_step) - latest_object_graph, _ = checkpointable.save( - file_prefix=checkpoint_prefix, - root_checkpointable=root) - self.assertEqual((training_continuation + 1) * num_training_steps, - root.global_step.numpy()) - - def testUsageGraph(self): - """Expected usage when graph building.""" - with context.graph_mode(): - num_training_steps = 10 - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - latest_object_graph = None - for training_continuation in range(3): - with ops.Graph().as_default(): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Root(optimizer=optimizer, network=network) - input_value = constant_op.constant([[3.]]) - train_op = optimizer.minimize( - network(input_value), - global_step=root.global_step) - init_op = variables.global_variables_initializer() - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) - with self.test_session(graph=ops.get_default_graph()) as session: - if checkpoint_path is None: - self.assertEqual(0, training_continuation) - session.run(init_op) - # Another alternative would be to run initializers automatically - # if no checkpoint is being loaded. This would make deferred - # loading a bit more useful with graph execution. - else: - checkpointable.restore( - save_path=checkpoint_path, - root_checkpointable=root, - object_graph_proto=latest_object_graph, - session=session) - for _ in range(num_training_steps): - session.run(train_op) - latest_object_graph, _ = checkpointable.save( - file_prefix=checkpoint_prefix, - root_checkpointable=root, - session=session) - self.assertEqual((training_continuation + 1) * num_training_steps, - session.run(root.global_step)) - - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testVariableNameEscaping(self): - self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) - self.assertEqual(r"b", self._get_checkpoint_name(r"b")) - self.assertEqual(r"c_S__", self._get_checkpoint_name(r"c/")) - self.assertEqual(r"d_S___S_._", self._get_checkpoint_name(r"d/_S__")) - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - root.track_checkpointable(leaf, name="leaf") - leaf.add_variable(name="v", shape=[]) - named_variables, _ = checkpointable._serialize_object_graph(root) - variable_name, = named_variables.keys() - self.assertEqual(r"leaf/v", variable_name) - - @test_util.run_in_graph_and_eager_modes() - def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - with self.assertRaisesRegexp(ValueError, "invalid name"): - # Leading dashes are reserved, which avoids conflicts with un-named edges - # in paths and the optimizer slots identifier. - root.track_checkpointable(leaf, name="-unnamed-12") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..30c4103c5aa52a74bcc8f72c7e1df186c9f7f591 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -0,0 +1,136 @@ +"""Utilities for working with Checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import checkpointable as core_checkpointable +from tensorflow.python.training import saver as saver_lib + + +class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + """Wraps save and restore callbacks as a `SaveableObject`.""" + + def __init__(self, name, dtype, save_callback, restore_callback): + self._restore_callback = restore_callback + spec = saver_lib.BaseSaverBuilder.SaveSpec( + tensor=save_callback, + slice_spec="", + name=name, + dtype=dtype) + super(_CallbackSaveable, self).__init__( + save_callback, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return self._restore_callback(tensor) + + +class _SplitDependency(core_checkpointable.CheckpointableBase): + """Looks like a regular variable while synchronizing save/restores.""" + + def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, + fill_save_buffer_fn, consume_restore_buffer_fn): + self._save_buffer = save_buffer + self._restore_buffer = restore_buffer + self._name = name + self._dtype = dtype + self._num_components = num_components + self._fill_save_buffer_fn = fill_save_buffer_fn + self._consume_restore_buffer_fn = consume_restore_buffer_fn + + def _save(self): + """Pull from the shared buffer, populating it if necessary.""" + if self._name not in self._save_buffer: + if self._save_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be saved together.") % (self._name, self)) + self._fill_save_buffer_fn(self._save_buffer) + return self._save_buffer.pop(self._name) + + def _restore(self, tensor): + """Push into the shared buffer, flushing it if necessary.""" + if self._name in self._restore_buffer: + raise AssertionError( + ("Split dependency %s (%s) unsynchronized. Split dependencies must " + "be restored together.") % (self._name, self)) + self._restore_buffer[self._name] = tensor + if len(self._restore_buffer) == self._num_components: + op = self._consume_restore_buffer_fn(self._restore_buffer) + self._restore_buffer.clear() + return op + else: + return control_flow_ops.no_op() + + def _gather_saveables_for_checkpoint(self): + """Looks to Checkpointable like a regular variable.""" + return { + core_checkpointable.VARIABLE_VALUE_KEY: + functools.partial(_CallbackSaveable, + dtype=self._dtype, + save_callback=self._save, + restore_callback=self._restore) + } + + +def split_dependency(component_names, component_dtypes, + fill_save_buffer_fn, consume_restore_buffer_fn): + """Creates multiple dependencies with a synchronized save/restore. + + Useful when a single op produces `Tensor`s which should each be saved under + different objects, or when `Tensor`s saved with many different objects need to + be restored together as inputs to a single op (i.e. an object which uses a + single fused op may be swapped out for a subgraph of objects, and these two + programs are checkpoint compatible). + + Args: + component_names: A sequence of names for the split + dependencies. `fill_save_buffer_fn` must add these keys to the dictionary + it is passed, and `consume_restore_buffer_fn` will receive a dictionary + with these keys. + component_dtypes: Data types for the `Tensor`s being saved and restored, a + sequence corresponding to `component_names`. + fill_save_buffer_fn: A function which takes an empty dictionary as an + argument and adds `Tensor`s with `component_names` as keys. These + `Tensor`s will be saved as if they were individual variables. + consume_restore_buffer_fn: A function which takes a dictionary with + `component_names` as keys mapping to restored individual `Tensor`s and + returns a restore op (or if executing eagerly, runs the restoration and + may return `None`). + + Returns: + A dictionary mapping from names to Checkpointable objects. If one is + reachable from an object as a dependency, the others should be too; adding + dependencies on some but not all of the objects will result in errors. + """ + save_buffer = {} + restore_buffer = {} + split_dependencies = {} + for name, dtype in zip(component_names, component_dtypes): + split_dependencies[name] = _SplitDependency( + save_buffer=save_buffer, + restore_buffer=restore_buffer, + name=name, + dtype=dtype, + num_components=len(component_names), + fill_save_buffer_fn=fill_save_buffer_fn, + consume_restore_buffer_fn=consume_restore_buffer_fn) + return split_dependencies diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bd42d405db9d1275c83636dc83090fa11b0b74b1 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -0,0 +1,112 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.eager.python import checkpointable_utils as contrib_checkpointable_utils +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import checkpointable_utils + + +def _split_variable_closure(variable): + def _fill_save_buffer_fn(save_buffer): + save_buffer["first_half"] = variable[:2] + save_buffer["second_half"] = variable[2:] + return _fill_save_buffer_fn + + +def _combine_variable_closure(variable): + def _consume_restore_buffer_fn(restore_buffer): + return variable.assign( + array_ops.concat([restore_buffer["first_half"], + restore_buffer["second_half"]], + axis=0)) + return _consume_restore_buffer_fn + + +class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): + + def __init__(self): + self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) + split_dependencies = contrib_checkpointable_utils.split_dependency( + component_names=("first_half", "second_half"), + component_dtypes=(self.combined.dtype,) * 2, + fill_save_buffer_fn=_split_variable_closure( + self.combined), + consume_restore_buffer_fn=_combine_variable_closure( + self.combined)) + for name, dep in split_dependencies.items(): + self._track_checkpointable(dep, name=name) + + +class HasRegularDeps(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + self.second_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class OnlyOneDep(checkpointable.Checkpointable): + + def __init__(self): + self.first_half = resource_variable_ops.ResourceVariable([0., 0.]) + + +class SplitTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestoreSplitDep(self): + save_checkpoint = checkpointable_utils.Checkpoint( + dep=SaveTensorSlicesAsDeps()) + self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = save_checkpoint.save(checkpoint_prefix) + + regular_deps = HasRegularDeps() + regular_restore_checkpoint = checkpointable_utils.Checkpoint( + dep=regular_deps) + regular_restore_checkpoint.restore( + save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(regular_deps.first_half)) + self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half)) + + one_dep = OnlyOneDep() + one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep) + status = one_dep_restore_checkpoint.restore(save_path) + with self.assertRaises(AssertionError): + # Missing the second dependency. + status.assert_consumed() + status.run_restore_ops() + self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half)) + + restore_checkpoint = checkpointable_utils.Checkpoint() + status = restore_checkpoint.restore(save_path) + restore_checkpoint.dep = SaveTensorSlicesAsDeps() + status.assert_consumed().run_restore_ops() + self.assertAllEqual( + [1., 2., 3., 4.], + self.evaluate(restore_checkpoint.dep.combined)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index d177bfeab2d1fdc05d7ced54df8723fae2c77fdb..0783d1b5d70e502e6edd80b59f37fdd93b413e12 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -27,11 +27,12 @@ from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training.saver import BaseSaverBuilder _uid_counter = 0 _uid_lock = threading.Lock() @@ -45,8 +46,13 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) -class Iterator(object): - """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" +class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): + """An iterator producing tf.Tensor objects from a tf.data.Dataset. + + NOTE: Unlike the iterator created by the + @{tf.data.Dataset.make_one_shot_iterator} method, this class enables + additional experimental functionality, such as prefetching to the GPU. + """ def __init__(self, dataset): """Creates a new iterator over the given dataset. @@ -65,39 +71,21 @@ class Iterator(object): dataset: A `tf.data.Dataset` object. Raises: + TypeError: If `dataset` is an unsupported type. RuntimeError: When invoked without eager execution enabled. """ + if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access + raise TypeError( + "`tf.contrib.data.prefetch_to_device()` is not compatible with " + "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate " + "over the dataset instead.") - if not context.in_eager_mode(): - raise RuntimeError( - "{} objects can only be used when eager execution is enabled, use " - "tf.data.Dataset.make_iterator or " - "tf.data.Dataset.make_one_shot_iterator for graph construction". - format(type(self))) - with ops.device("/device:CPU:0"): - ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access - self._output_classes = dataset.output_classes - self._output_types = dataset.output_types - self._output_shapes = dataset.output_shapes - self._flat_output_types = nest.flatten( - sparse.as_dense_types(self._output_types, self._output_classes)) - self._flat_output_shapes = nest.flatten( - sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.iterator( - shared_name="", - container=_generate_shared_name("eageriterator"), - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device="/device:CPU:0") - self._device = context.context().device_name - self._buffer_resource_handle = None + super(Iterator, self).__init__(dataset) if not context.context().device_spec.device_type: is_remote_device = False else: is_remote_device = context.context().device_spec.device_type != "CPU" + self._buffer_resource_handle = None if is_remote_device: with ops.device("/device:CPU:0"): iter_string_handle = gen_dataset_ops.iterator_to_string_handle( @@ -106,7 +94,7 @@ class Iterator(object): @function.Defun(dtypes.string) def remote_fn(h): remote_iterator = iterator_ops.Iterator.from_string_handle( - h, self._output_types, self._output_shapes) + h, self.output_types, self.output_shapes, self.output_classes) return remote_iterator.get_next() remote_fn.add_to_graph(None) @@ -117,96 +105,53 @@ class Iterator(object): f=remote_fn, target_device=target, buffer_size=10, - thread_pool_size=1, container="", shared_name=_generate_shared_name("function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long handle=self._buffer_resource_handle, handle_device=self._device) - def __iter__(self): - return self - - def __next__(self): # For Python 3 compatibility - return self.next() - def _next_internal(self): """Returns a nested structure of `tf.Tensor`s containing the next element. """ - with ops.device(self._device): + # This runs in sync mode as iterators use an error status to communicate + # that there is no more data to iterate over. + # TODO(b/77291417): Fix + with context.execution_mode(context.SYNC): if self._buffer_resource_handle is not None: - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) + with ops.device(self._device): + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) else: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` - # because in eager mode this code will run synchronously on the calling - # thread. Therefore we do not need to make a defensive context switch - # to a background thread, and can achieve a small constant performance - # boost by invoking the iterator synchronously. - ret = gen_dataset_ops.iterator_get_next_sync( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - - return sparse.deserialize_sparse_tensors( - nest.pack_sequence_as(self._output_types, ret), self._output_types, - self._output_shapes, self._output_classes) - - def next(self): - """Returns a nested structure of `tf.Tensor`s containing the next element. - """ - try: - return self._next_internal() - except errors.OutOfRangeError: - raise StopIteration - - @property - def output_classes(self): - """Returns the class of each component of an element of this iterator. - - The expected values are `tf.Tensor` and `tf.SparseTensor`. + return super(Iterator, self)._next_internal() - Returns: - A nested structure of Python `type` objects corresponding to each - component of an element of this dataset. - """ - return self._output_classes + # TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset + # attributes(potential). - @property - def output_shapes(self): - """Returns the shape of each component of an element of this iterator. + class _Saveable(BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" - Returns: - A nested structure of `tf.TensorShape` objects corresponding to each - component of an element of this dataset. - """ - return self._output_shapes + def __init__(self, iterator_resource, name): + serialized_iterator = gen_dataset_ops.serialize_iterator( + iterator_resource) + specs = [ + BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE") + ] + # pylint: disable=protected-access + super(Iterator._Saveable, self).__init__(iterator_resource, specs, name) - @property - def output_types(self): - """Returns the type of each component of an element of this iterator. + def restore(self, restored_tensors, restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, + restored_tensors[0]) - Returns: - A nested structure of `tf.DType` objects corresponding to each component - of an element of this dataset. - """ - return self._output_types + def _gather_saveables_for_checkpoint(self): - def get_next(self, name=None): - """Returns a nested structure of `tf.Tensor`s containing the next element. + def _saveable_factory(name): + return self._Saveable(self._resource, name) - Args: - name: (Optional.) A name for the created operation. Currently unused. - - Returns: - A nested structure of `tf.Tensor` objects. - - Raises: - `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. - """ - del name - return self._next_internal() + return {"ITERATOR": _saveable_factory} diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a1611e92b113839c2dd2a3b2560b0ba90c0a7ef0..7b123707cc3a26073088cf2c57c6211e831c19fd 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,11 +16,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + +import threading import time import numpy as np from tensorflow.contrib import lookup +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.contrib.data.python.ops import threadpool +from tensorflow.contrib.data.python.ops import unique from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test @@ -31,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.training import checkpointable_utils class IteratorTest(test.TestCase): @@ -41,6 +48,18 @@ class IteratorTest(test.TestCase): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) + def testBasicOneShotIterator(self): + got = [] + for t in Dataset.range(4).make_one_shot_iterator(): + got.append(t.numpy()) + self.assertAllEqual([0, 1, 2, 3], got) + + def testBasicImplicitIterator(self): + got = [] + for t in Dataset.range(4): + got.append(t.numpy()) + self.assertAllEqual([0, 1, 2, 3], got) + def testGetNext(self): iterator = datasets.Iterator(Dataset.range(4)) self.assertEqual(0, iterator.get_next().numpy()) @@ -50,6 +69,15 @@ class IteratorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): iterator.get_next() + def testGetNextOneShotIterator(self): + iterator = Dataset.range(4).make_one_shot_iterator() + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + self.assertEqual(2, iterator.get_next().numpy()) + self.assertEqual(3, iterator.get_next().numpy()) + with self.assertRaises(errors.OutOfRangeError): + iterator.get_next() + def testMultipleIteratorsOnTheSameDataset(self): ds = Dataset.range(4) it1 = datasets.Iterator(ds) @@ -165,6 +193,105 @@ class IteratorTest(test.TestCase): x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) + def testTensorsExplicitPrefetchToDevice(self): + ds = Dataset.from_tensor_slices([0., 1.]) + ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name())) + + with self.assertRaisesRegexp(TypeError, 'prefetch_to_device'): + datasets.Iterator(ds) + + for i, x in enumerate(ds): + with ops.device(test.gpu_device_name()): + x = math_ops.add(x, x) + self.assertEqual(float(i) + float(i), x.numpy()) + + def testOverrideThreadPool(self): + + def get_thread_id(_): + # Python creates a dummy thread object to represent the current + # thread when called from an "alien" thread (such as a + # `PrivateThreadPool` thread in this case). It does not include + # the TensorFlow-given display name, but it has a unique + # identifier that maps one-to-one with the underlying OS thread. + return np.array(threading.current_thread().ident).astype(np.int64) + + for num_threads in [1, 2, 4, 8, 16]: + + dataset = ( + Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) + + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, display_name='private_thread_pool_%d' % num_threads)) + + thread_ids = [] + for next_element in datasets.Iterator(dataset): + thread_ids.append(next_element) + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) + + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator = datasets.Iterator(dataset) + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertAllEqual([1, 4], iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator.get_next().numpy()) + self.assertAllEqual([25, 36], iterator.get_next().numpy()) + + def testSaveRestoreMultipleIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset.map(math_ops.square).batch(2) + iterator_1 = datasets.Iterator(dataset) + iterator_2 = datasets.Iterator(dataset) + dataset_2 = Dataset.range(10) + iterator_3 = datasets.Iterator(dataset_2) + + checkpoint = checkpointable_utils.Checkpoint( + iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) + self.assertAllEqual([1, 4], iterator_1.get_next().numpy()) + self.assertEqual(0, iterator_3.get_next().numpy()) + self.assertEqual(1, iterator_3.get_next().numpy()) + self.assertEqual(2, iterator_3.get_next().numpy()) + + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertAllEqual([9, 16], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + checkpoint.restore(save_path) + self.assertAllEqual([9, 16], iterator_1.get_next().numpy()) + self.assertAllEqual([1, 4], iterator_2.get_next().numpy()) + self.assertEqual(3, iterator_3.get_next().numpy()) + + def testRestoreExhaustedIterator(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + dataset = Dataset.range(3) + iterator = datasets.Iterator(dataset) + + checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertEqual(2, iterator.get_next().numpy()) + checkpoint.restore(save_path) + self.assertEqual(2, iterator.get_next().numpy()) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index 68e7b5421fec7f73f10e381ca45f9d900de299d7..7949a3f6da293abdd85512209242bae76ab4d816 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -22,12 +22,12 @@ import six from tensorflow.contrib.eager.python import datasets from tensorflow.contrib.eager.python import metrics -from tensorflow.contrib.summary import summary_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops class Evaluator(object): @@ -57,7 +57,7 @@ class Evaluator(object): self._model = model self._metrics = {} self._evaluators = {} - if context.in_graph_mode(): + if not context.executing_eagerly(): self.call = function.defun(self.call) # ---- API for users ---- @@ -90,7 +90,7 @@ class Evaluator(object): Only for graph execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Evaluator.init_variables() not needed when " "eager execution is enabled.") return control_flow_ops.group([m.init_variables() for _, m in self.metrics]) @@ -113,7 +113,8 @@ class Evaluator(object): with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() - if context.in_eager_mode(): + + if context.executing_eagerly(): return f() else: return function.defun(f)() @@ -158,16 +159,16 @@ class Evaluator(object): @end_compatibility """ summary_logdir = kwargs.pop("summary_logdir", None) - if context.in_graph_mode(): - call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), - *args, **kwargs) - init_op = self.init_variables() - results_op = self.all_metric_results(summary_logdir) - return (init_op, call_op, results_op) - # Eager case - for example in datasets.Iterator(dataset): - self.__call__(example, *args, **kwargs) - return self.all_metric_results(summary_logdir) + if context.executing_eagerly(): + for example in datasets.Iterator(dataset): + self.__call__(example, *args, **kwargs) + return self.all_metric_results(summary_logdir) + # Graph construction + call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), *args, + **kwargs) + init_op = self.init_variables() + results_op = self.all_metric_results(summary_logdir) + return (init_op, call_op, results_op) @staticmethod def run_evaluation(init_op, call_op, results_op, sess=None): @@ -192,7 +193,7 @@ class Evaluator(object): Only for graph execution. @end_compatibility """ - if context.in_eager_mode(): + if context.executing_eagerly(): raise RuntimeError("Evaluator.run_evaluation() not supported when " "eager execution is enabled.") sess = sess or ops.get_default_session() diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 15a21885f66eface291a39fa0ee1ff28bc297548..c1fd9e0ed020beeb722204edf1adfe1dfcf8ff03 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -8,7 +8,6 @@ py_library( deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", "//tensorflow/contrib/eager/python/examples/linear_regression", - "//tensorflow/contrib/eager/python/examples/mnist", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b9ac79f46c83bb709918e3b72830b90ddcfd71b4..b80c90902353709b7f739585291ec3b5890c27c7 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -32,10 +32,11 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe from tensorflow.examples.tutorials.mnist import input_data +layers = tf.keras.layers FLAGS = None -class Discriminator(tfe.Network): +class Discriminator(tf.keras.Model): """GAN Discriminator. A network to differentiate between generated and real handwritten digits. @@ -56,19 +57,15 @@ class Discriminator(tfe.Network): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', - data_format=data_format, - activation=tf.tanh)) - self.pool1 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, - data_format=data_format, - activation=tf.tanh)) - self.pool2 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.flatten = self.track_layer(tf.layers.Flatten()) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) - self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + self.conv1 = layers.Conv2D( + 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) + self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.conv2 = layers.Conv2D( + 128, 5, data_format=data_format, activation=tf.tanh) + self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format) + self.flatten = layers.Flatten() + self.fc1 = layers.Dense(1024, activation=tf.tanh) + self.fc2 = layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. @@ -95,7 +92,7 @@ class Discriminator(tfe.Network): return x -class Generator(tfe.Network): +class Generator(tf.keras.Model): """Generator of handwritten digits similar to the ones in the MNIST dataset. """ @@ -116,18 +113,17 @@ class Generator(tfe.Network): else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] - self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, - activation=tf.tanh)) + self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) - self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( - 64, 4, strides=2, activation=None, data_format=data_format)) + self.conv1 = layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) - self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( - 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + self.conv2 = layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): """Return a batch of generated images. @@ -168,7 +164,8 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): """ loss_on_real = tf.losses.sigmoid_cross_entropy( - tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + tf.ones_like(discriminator_real_outputs), + discriminator_real_outputs, label_smoothing=0.25) loss_on_generated = tf.losses.sigmoid_cross_entropy( tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) @@ -198,9 +195,9 @@ def generator_loss(discriminator_gen_outputs): return loss -def train_one_epoch(generator, discriminator, - generator_optimizer, discriminator_optimizer, - dataset, log_interval, noise_dim): +def train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, dataset, step_counter, + log_interval, noise_dim): """Trains `generator` and `discriminator` models on `dataset`. Args: @@ -209,7 +206,8 @@ def train_one_epoch(generator, discriminator, generator_optimizer: Optimizer to use for generator. discriminator_optimizer: Optimizer to use for discriminator. dataset: Dataset of images to train on. - log_interval: How many global steps to wait between logging and collecting + step_counter: An integer variable, used to write summaries regularly. + log_interval: How many steps to wait between logging and collecting summaries. noise_dim: Dimension of noise vector to use. """ @@ -218,18 +216,23 @@ def train_one_epoch(generator, discriminator, total_discriminator_loss = 0.0 for (batch_index, images) in enumerate(tfe.Iterator(dataset)): with tf.device('/cpu:0'): - tf.assign_add(tf.train.get_global_step(), 1) + tf.assign_add(step_counter, 1) - with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + with tf.contrib.summary.record_summaries_every_n_global_steps( + log_interval, global_step=step_counter): current_batch_size = images.shape[0] - noise = tf.random_uniform(shape=[current_batch_size, noise_dim], - minval=-1., maxval=1., seed=batch_index) + noise = tf.random_uniform( + shape=[current_batch_size, noise_dim], + minval=-1., + maxval=1., + seed=batch_index) with tfe.GradientTape(persistent=True) as g: generated_images = generator(noise) - tf.contrib.summary.image('generated_images', - tf.reshape(generated_images, [-1, 28, 28, 1]), - max_images=10) + tf.contrib.summary.image( + 'generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) discriminator_gen_outputs = discriminator(generated_images) discriminator_real_outputs = discriminator(images) @@ -244,18 +247,16 @@ def train_one_epoch(generator, discriminator, discriminator_grad = g.gradient(discriminator_loss_val, discriminator.variables) - with tf.variable_scope('generator'): - generator_optimizer.apply_gradients(zip(generator_grad, - generator.variables)) - with tf.variable_scope('discriminator'): - discriminator_optimizer.apply_gradients(zip(discriminator_grad, - discriminator.variables)) + generator_optimizer.apply_gradients( + zip(generator_grad, generator.variables)) + discriminator_optimizer.apply_gradients( + zip(discriminator_grad, discriminator.variables)) if log_interval and batch_index > 0 and batch_index % log_interval == 0: print('Batch #%d\tAverage Generator Loss: %.6f\t' - 'Average Discriminator Loss: %.6f' % ( - batch_index, total_generator_loss/batch_index, - total_discriminator_loss/batch_index)) + 'Average Discriminator Loss: %.6f' % + (batch_index, total_generator_loss / batch_index, + total_discriminator_loss / batch_index)) def main(_): @@ -266,18 +267,18 @@ def main(_): # Load the datasets data = input_data.read_data_sets(FLAGS.data_dir) - dataset = (tf.data.Dataset - .from_tensor_slices(data.train.images) - .shuffle(60000) - .batch(FLAGS.batch_size)) - - # Create the models and optimizers - generator = Generator(data_format) - discriminator = Discriminator(data_format) - with tf.variable_scope('generator'): - generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) - with tf.variable_scope('discriminator'): - discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + dataset = ( + tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000) + .batch(FLAGS.batch_size)) + + # Create the models and optimizers. + model_objects = { + 'generator': Generator(data_format), + 'discriminator': Discriminator(data_format), + 'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), + 'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr), + 'step_counter': tf.train.get_or_create_global_step(), + } # Prepare summary writer and checkpoint info summary_writer = tf.contrib.summary.create_summary_file_writer( @@ -286,28 +287,22 @@ def main(_): latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) if latest_cpkt: print('Using latest checkpoint at ' + latest_cpkt) + checkpoint = tfe.Checkpoint(**model_objects) + # Restore variables on creation if a checkpoint exists. + checkpoint.restore(latest_cpkt) with tf.device(device): - for epoch in range(1, 101): - with tfe.restore_variables_on_create(latest_cpkt): - global_step = tf.train.get_or_create_global_step() - start = time.time() - with summary_writer.as_default(): - train_one_epoch(generator, discriminator, generator_optimizer, - discriminator_optimizer, - dataset, FLAGS.log_interval, FLAGS.noise) - end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) - - all_variables = ( - generator.variables - + discriminator.variables - + generator_optimizer.variables() - + discriminator_optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) + for _ in range(100): + start = time.time() + with summary_writer.as_default(): + train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval, + noise_dim=FLAGS.noise, **model_objects) + end = time.time() + checkpoint.save(checkpoint_prefix) + print('\nTrain time for epoch #%d (step %d): %f' % + (checkpoint.save_counter.numpy(), + checkpoint.step_counter.numpy(), + end - start)) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py index 4a3ca8d82bc2619b05a734f6d2e58431c1a45995..bd35e50c1f434d167c5a8c5aa7d224912523ce28 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -62,7 +62,7 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): for _ in range(measure_batches)] measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images) - tf.train.get_or_create_global_step() + step_counter = tf.train.get_or_create_global_step() with tf.device(device()): # Create the models and optimizers generator = mnist.Generator(data_format()) @@ -78,13 +78,15 @@ class MnistEagerGanBenchmark(tf.test.Benchmark): # warm up mnist.train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, - burn_dataset, log_interval=SUMMARY_INTERVAL, + burn_dataset, step_counter, + log_interval=SUMMARY_INTERVAL, noise_dim=NOISE_DIM) # measure start = time.time() mnist.train_one_epoch(generator, discriminator, generator_optimizer, discriminator_optimizer, - measure_dataset, log_interval=SUMMARY_INTERVAL, + measure_dataset, step_counter, + log_interval=SUMMARY_INTERVAL, noise_dim=NOISE_DIM) self._report('train', start, measure_batches, batch_size) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index f86331af6f7928f0f86c888e22706c6e0a5978b2..2f6cfdf31e852d5d69a7a87980c9a441da504cf2 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -22,6 +22,7 @@ cuda_py_test( ":linear_regression", "//tensorflow:tensorflow_py", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) cuda_py_test( diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 6ce4de6ee0bf50400eff339ac04e132252a2b53e..4e1380afb2e6e722de65c691d4fbf44621072e87 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -32,24 +32,16 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe +layers = tf.keras.layers -class LinearModel(tfe.Network): - """A TensorFlow linear regression model. - Uses TensorFlow's eager execution. - - For those familiar with TensorFlow graphs, notice the absence of - `tf.Session`. The `forward()` method here immediately executes and - returns output values. The `loss()` method immediately compares the - output of `forward()` with the target and returns the MSE loss value. - The `fit()` performs gradient-descent training on the model's weights - and bias. - """ +class LinearModel(tf.keras.Model): + """A TensorFlow linear regression model.""" def __init__(self): """Constructs a LinearModel object.""" super(LinearModel, self).__init__() - self._hidden_layer = self.track_layer(tf.layers.Dense(1)) + self._hidden_layer = layers.Dense(1) def call(self, xs): """Invoke the linear model. @@ -64,7 +56,7 @@ class LinearModel(tfe.Network): def mean_square_loss(model, xs, ys): - return tf.reduce_mean(tf.square(model(xs) - ys)) + return tf.reduce_mean(tf.square(tf.subtract(model(xs), ys))) def fit(model, dataset, optimizer, verbose=False, logdir=None): diff --git a/tensorflow/contrib/eager/python/examples/mnist/BUILD b/tensorflow/contrib/eager/python/examples/mnist/BUILD deleted file mode 100644 index c61ec2dbae60a782c0e6589701554b045dcb92ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -py_binary( - name = "mnist", - srcs = ["mnist.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/eager/python:tfe", - "//tensorflow/examples/tutorials/mnist:input_data", - ], -) - -cuda_py_test( - name = "mnist_test", - srcs = ["mnist_test.py"], - additional_deps = [ - ":mnist", - "//tensorflow/contrib/eager/python:tfe", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "mnist_graph_test", - srcs = ["mnist_graph_test.py"], - additional_deps = [ - ":mnist", - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/mnist/README.md b/tensorflow/contrib/eager/python/examples/mnist/README.md index e987996b88ccf54a322749aadec4f9840760a90f..d1c079ff6b5cb187bbcfe2742293982b1bedd2d4 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/README.md +++ b/tensorflow/contrib/eager/python/examples/mnist/README.md @@ -1,10 +1 @@ -Classification model for the MNIST dataset using eager execution. - -To run: - -``` -python mnist.py -``` - -`mnist_graph_test.py` demonstrates that the same code that is executed eagerly -in `mnist.py` is used to construct a TensorFlow graph. +See https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py deleted file mode 100644 index 772f59562ba27cce510c82681f491d005298f44c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A deep MNIST classifier using convolutional layers. - -Sample usage: - python mnist.py --help -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import sys -import time - -import tensorflow as tf - -import tensorflow.contrib.eager as tfe -from tensorflow.examples.tutorials.mnist import input_data - -FLAGS = None - - -class MNISTModel(tfe.Network): - """MNIST Network. - - Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py - and - https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py - - But written using the tf.layers API. - """ - - def __init__(self, data_format): - """Creates a model for classifying a hand-written digit. - - Args: - data_format: Either 'channels_first' or 'channels_last'. - 'channels_first' is typically faster on GPUs while 'channels_last' is - typically faster on CPUs. See - https://www.tensorflow.org/performance/performance_guide#data_formats - """ - super(MNISTModel, self).__init__(name='') - if data_format == 'channels_first': - self._input_shape = [-1, 1, 28, 28] - else: - assert data_format == 'channels_last' - self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer( - tf.layers.Conv2D(32, 5, data_format=data_format, activation=tf.nn.relu)) - self.conv2 = self.track_layer( - tf.layers.Conv2D(64, 5, data_format=data_format, activation=tf.nn.relu)) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.nn.relu)) - self.fc2 = self.track_layer(tf.layers.Dense(10)) - self.dropout = self.track_layer(tf.layers.Dropout(0.5)) - self.max_pool2d = self.track_layer( - tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='SAME', data_format=data_format)) - - def call(self, inputs, training): - """Computes labels from inputs. - - Users should invoke __call__ to run the network, which delegates to this - method (and not call this method directly). - - Args: - inputs: A batch of images as a Tensor with shape [batch_size, 784]. - training: True if invoked in the context of training (causing dropout to - be applied). False otherwise. - - Returns: - A Tensor with shape [batch_size, 10] containing the predicted logits - for each image in the batch, for each of the 10 classes. - """ - - x = tf.reshape(inputs, self._input_shape) - x = self.conv1(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.max_pool2d(x) - x = tf.layers.flatten(x) - x = self.fc1(x) - x = self.dropout(x, training=training) - x = self.fc2(x) - return x - - -def loss(predictions, labels): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=predictions, labels=labels)) - - -def compute_accuracy(predictions, labels): - return tf.reduce_sum( - tf.cast( - tf.equal( - tf.argmax(predictions, axis=1, - output_type=tf.int64), - tf.argmax(labels, axis=1, - output_type=tf.int64)), - dtype=tf.float32)) / float(predictions.shape[0].value) - - -def train_one_epoch(model, optimizer, dataset, log_interval=None): - """Trains model on `dataset` using `optimizer`.""" - - tf.train.get_or_create_global_step() - - for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): - with tf.contrib.summary.record_summaries_every_n_global_steps(10): - with tfe.GradientTape() as tape: - prediction = model(images, training=True) - loss_value = loss(prediction, labels) - tf.contrib.summary.scalar('loss', loss_value) - tf.contrib.summary.scalar('accuracy', - compute_accuracy(prediction, labels)) - grads = tape.gradient(loss_value, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) - if log_interval and batch % log_interval == 0: - print('Batch #%d\tLoss: %.6f' % (batch, loss_value)) - - -def test(model, dataset): - """Perform an evaluation of `model` on the examples from `dataset`.""" - avg_loss = tfe.metrics.Mean('loss') - accuracy = tfe.metrics.Accuracy('accuracy') - - for (images, labels) in tfe.Iterator(dataset): - predictions = model(images, training=False) - avg_loss(loss(predictions, labels)) - accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64), - tf.argmax(labels, axis=1, output_type=tf.int64)) - print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' % - (avg_loss.result(), 100 * accuracy.result())) - with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar('loss', avg_loss.result()) - tf.contrib.summary.scalar('accuracy', accuracy.result()) - - -def load_data(data_dir): - """Returns training and test tf.data.Dataset objects.""" - data = input_data.read_data_sets(data_dir, one_hot=True) - train_ds = tf.data.Dataset.from_tensor_slices((data.train.images, - data.train.labels)) - test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels)) - return (train_ds, test_ds) - - -def main(_): - tfe.enable_eager_execution() - - (device, data_format) = ('/gpu:0', 'channels_first') - if FLAGS.no_gpu or tfe.num_gpus() <= 0: - (device, data_format) = ('/cpu:0', 'channels_last') - print('Using device %s, and data format %s.' % (device, data_format)) - - # Load the datasets - (train_ds, test_ds) = load_data(FLAGS.data_dir) - train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size) - - # Create the model and optimizer - model = MNISTModel(data_format) - optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum) - - if FLAGS.output_dir: - train_dir = os.path.join(FLAGS.output_dir, 'train') - test_dir = os.path.join(FLAGS.output_dir, 'eval') - tf.gfile.MakeDirs(FLAGS.output_dir) - else: - train_dir = None - test_dir = None - summary_writer = tf.contrib.summary.create_file_writer( - train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_file_writer( - test_dir, flush_millis=10000, name='test') - checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') - - with tf.device(device): - for epoch in range(1, 11): - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(FLAGS.checkpoint_dir)): - global_step = tf.train.get_or_create_global_step() - start = time.time() - with summary_writer.as_default(): - train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval) - end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) - with test_summary_writer.as_default(): - test(model, test_ds) - all_variables = ( - model.variables - + optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--data-dir', - type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - parser.add_argument( - '--batch-size', - type=int, - default=64, - metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument( - '--log-interval', - type=int, - default=10, - metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument( - '--output_dir', - type=str, - default=None, - metavar='N', - help='Directory to write TensorBoard summaries') - parser.add_argument( - '--checkpoint_dir', - type=str, - default='/tmp/tensorflow/mnist/checkpoints/', - metavar='N', - help='Directory to save checkpoints in (once per epoch)') - parser.add_argument( - '--lr', - type=float, - default=0.01, - metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument( - '--momentum', - type=float, - default=0.5, - metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument( - '--no-gpu', - action='store_true', - default=False, - help='disables GPU usage even if a GPU is available') - - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py deleted file mode 100644 index 1af26553120b34d4682b17b1c29c81dc65e421d4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.mnist import mnist - - -def data_format(): - return "channels_first" if tf.test.is_gpu_available() else "channels_last" - - -class MNISTGraphTest(tf.test.TestCase): - - def testTrainGraph(self): - # The MNISTModel class can be executed eagerly (as in mnist.py and - # mnist_test.py) and also be used to construct a TensorFlow graph, which is - # then trained in a session. - with tf.Graph().as_default(): - # Generate some random data. - batch_size = 64 - images = np.random.randn(batch_size, 784).astype(np.float32) - digits = np.random.randint(low=0, high=10, size=batch_size) - labels = np.zeros((batch_size, 10)) - labels[np.arange(batch_size), digits] = 1. - - # Create a model, optimizer, and dataset as would be done - # for eager execution as well. - model = mnist.MNISTModel(data_format()) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) - dataset = tf.data.Dataset.from_tensors((images, labels)) - - # Define the loss tensor (as opposed to a loss function when - # using eager execution). - (images, labels) = dataset.make_one_shot_iterator().get_next() - predictions = model(images, training=True) - loss = mnist.loss(predictions, labels) - - train_op = optimizer.minimize(loss) - init = tf.global_variables_initializer() - with tf.Session() as sess: - # Variables have to be initialized in the session. - sess.run(init) - # Train using the optimizer. - sess.run(train_op) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py deleted file mode 100644 index 136085eba21284a42282395e54f32c33bf63b5c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import tensorflow.contrib.eager as tfe -from tensorflow.contrib.eager.python.examples.mnist import mnist - - -def device(): - return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" - - -def data_format(): - return "channels_first" if tfe.num_gpus() else "channels_last" - - -def random_dataset(): - batch_size = 64 - images = tf.random_normal([batch_size, 784]) - digits = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32) - labels = tf.one_hot(digits, 10) - return tf.data.Dataset.from_tensors((images, labels)) - - -def train_one_epoch(defun=False): - model = mnist.MNISTModel(data_format()) - if defun: - model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.train_one_epoch(model, optimizer, dataset) - - -def evaluate(defun=False): - model = mnist.MNISTModel(data_format()) - dataset = random_dataset() - if defun: - model.call = tfe.defun(model.call) - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.test(model, dataset) - - -class MNISTTest(tf.test.TestCase): - - def testTrainOneEpoch(self): - train_one_epoch(defun=False) - - def testTest(self): - evaluate(defun=False) - - def testTrainOneEpochWithDefunCall(self): - train_one_epoch(defun=True) - - def testTestWithDefunCall(self): - evaluate(defun=True) - - -if __name__ == "__main__": - tfe.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index 9982fdb07eefa665379e7be095f4f8017d92cf97..a28bc8a43d7c90737c9baf9a634d736e9de52948 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -27,10 +27,11 @@ from __future__ import print_function import functools import tensorflow as tf -import tensorflow.contrib.eager as tfe +layers = tf.keras.layers -class _IdentityBlock(tfe.Network): + +class _IdentityBlock(tf.keras.Model): """_IdentityBlock is the block that has no conv layer at shortcut. Args: @@ -50,31 +51,24 @@ class _IdentityBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - data_format=data_format, - name=conv_name_base + '2b')) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) + self.conv2a = layers.Conv2D( + filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format) + self.bn2a = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = layers.Conv2D( + filters2, + kernel_size, + padding='same', + data_format=data_format, + name=conv_name_base + '2b') + self.bn2b = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -92,7 +86,7 @@ class _IdentityBlock(tfe.Network): return tf.nn.relu(x) -class _ConvBlock(tfe.Network): +class _ConvBlock(tf.keras.Model): """_ConvBlock is the block that has a conv layer at shortcut. Args: @@ -121,41 +115,35 @@ class _ConvBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - strides=strides, - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - name=conv_name_base + '2b', - data_format=data_format)) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) - - self.conv_shortcut = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - strides=strides, - name=conv_name_base + '1', - data_format=data_format)) - self.bn_shortcut = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')) + self.conv2a = layers.Conv2D( + filters1, (1, 1), + strides=strides, + name=conv_name_base + '2a', + data_format=data_format) + self.bn2a = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = layers.Conv2D( + filters2, + kernel_size, + padding='same', + name=conv_name_base + '2b', + data_format=data_format) + self.bn2b = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') + + self.conv_shortcut = layers.Conv2D( + filters3, (1, 1), + strides=strides, + name=conv_name_base + '1', + data_format=data_format) + self.bn_shortcut = layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '1') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -176,7 +164,8 @@ class _ConvBlock(tfe.Network): return tf.nn.relu(x) -class ResNet50(tfe.Network): +# pylint: disable=not-callable +class ResNet50(tf.keras.Model): """Instantiates the ResNet50 architecture. Args: @@ -220,32 +209,28 @@ class ResNet50(tfe.Network): self.include_top = include_top def conv_block(filters, stage, block, strides=(2, 2)): - l = _ConvBlock( + return _ConvBlock( 3, filters, stage=stage, block=block, data_format=data_format, strides=strides) - return self.track_layer(l) def id_block(filters, stage, block): - l = _IdentityBlock( + return _IdentityBlock( 3, filters, stage=stage, block=block, data_format=data_format) - return self.track_layer(l) - - self.conv1 = self.track_layer( - tf.layers.Conv2D( - 64, (7, 7), - strides=(2, 2), - data_format=data_format, - padding='same', - name='conv1')) + + self.conv1 = layers.Conv2D( + 64, (7, 7), + strides=(2, 2), + data_format=data_format, + padding='same', + name='conv1') bn_axis = 1 if data_format == 'channels_first' else 3 - self.bn_conv1 = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')) - self.max_pool = self.track_layer( - tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format)) + self.bn_conv1 = layers.BatchNormalization(axis=bn_axis, name='bn_conv1') + self.max_pool = layers.MaxPooling2D( + (3, 3), strides=(2, 2), data_format=data_format) self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) self.l2b = id_block([64, 64, 256], stage=2, block='b') @@ -267,13 +252,12 @@ class ResNet50(tfe.Network): self.l5b = id_block([512, 512, 2048], stage=5, block='b') self.l5c = id_block([512, 512, 2048], stage=5, block='c') - self.avg_pool = self.track_layer( - tf.layers.AveragePooling2D( - (7, 7), strides=(7, 7), data_format=data_format)) + self.avg_pool = layers.AveragePooling2D( + (7, 7), strides=(7, 7), data_format=data_format) if self.include_top: - self.fc1000 = self.track_layer( - tf.layers.Dense(classes, name='fc1000')) + self.flatten = layers.Flatten() + self.fc1000 = layers.Dense(classes, name='fc1000') else: reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] reduction_indices = tf.constant(reduction_indices) @@ -288,7 +272,7 @@ class ResNet50(tfe.Network): else: self.global_pooling = None - def call(self, input_tensor, training=False): + def call(self, input_tensor, training): x = self.conv1(input_tensor) x = self.bn_conv1(x, training=training) x = tf.nn.relu(x) @@ -317,7 +301,7 @@ class ResNet50(tfe.Network): x = self.avg_pool(x) if self.include_top: - return self.fc1000(tf.layers.flatten(x)) + return self.fc1000(self.flatten(x)) elif self.global_pooling: return self.global_pooling(x) else: diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 23317886e712323f4b520000e0fd372734fc53a1..551c76b0df71c88919df9cd6d81b4176b23b0ba3 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -55,7 +55,7 @@ class ResNet50GraphTest(tf.test.TestCase): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() @@ -114,7 +114,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 0ff8746884c288f824f5f22ab4c550370d0e0302..d6923293a374f29ab77be70fa9fea44efd1ea40b 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -64,28 +64,35 @@ def train_one_step(model, images, labels, optimizer): class ResNet50Test(tf.test.TestCase): - def _apply(self, defun=False): + def _apply(self, defun=False, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) + tfe.async_wait() self.assertEqual((2, 1000), output.shape) def test_apply(self): self._apply(defun=False) + def test_apply_async(self): + self._apply(defun=False, execution_mode=tfe.ASYNC) + def test_apply_with_defun(self): self._apply(defun=True) + def test_apply_with_defun_async(self): + self._apply(defun=True, execution_mode=tfe.ASYNC) + def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) @@ -95,10 +102,10 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) self.assertEqual((2, 2048), output.shape) - def test_train(self): + def _test_train(self, execution_mode=None): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() @@ -106,15 +113,22 @@ class ResNet50Test(tf.test.TestCase): with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): - with tf.device(device): + with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) + tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') + def test_train(self): + self._test_train() + + def test_train_async(self): + self._test_train(execution_mode=tfe.ASYNC) + def test_no_garbage(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) @@ -183,59 +197,84 @@ class ResNet50Benchmarks(tf.test.Benchmark): # a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False): - device, data_format = device_and_data_format() - model = resnet50.ResNet50(data_format) - if defun: - model.call = tfe.defun(model.call) - batch_size = 64 - num_burn = 5 - num_iters = 30 - with tf.device(device): - images, _ = random_batch(batch_size) - for _ in xrange(num_burn): - model(images).cpu() - gc.collect() - start = time.time() - for _ in xrange(num_iters): - model(images).cpu() - self._report(label, start, num_iters, device, batch_size, data_format) - - def benchmark_eager_apply(self): - self._benchmark_eager_apply('eager_apply', defun=False) - - def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) - - def _benchmark_eager_train(self, label, make_iterator, defun=False): - device, data_format = device_and_data_format() - for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size) - num_burn = 3 - num_iters = 10 + def _benchmark_eager_apply(self, label, defun=False, execution_mode=None): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) if defun: model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(0.1) - + batch_size = 64 + num_burn = 5 + num_iters = 30 with tf.device(device): - iterator = make_iterator((images, labels)) + images, _ = random_batch(batch_size) for _ in xrange(num_burn): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() gc.collect() - start = time.time() for _ in xrange(num_iters): - (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) - self._force_gpu_sync() + model(images, training=False).cpu() + if execution_mode: + tfe.async_wait() self._report(label, start, num_iters, device, batch_size, data_format) + def benchmark_eager_apply(self): + self._benchmark_eager_apply('eager_apply', defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + defun=False, + execution_mode=None): + with tfe.execution_mode(execution_mode): + device, data_format = device_and_data_format() + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size) + num_burn = 3 + num_iters = 10 + model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) + optimizer = tf.train.GradientDescentOptimizer(0.1) + + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_gpu_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_gpu_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + def benchmark_eager_train(self): self._benchmark_eager_train('eager_train', MockIterator, defun=False) + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + 'eager_train_async', + MockIterator, + defun=False, + execution_mode=tfe.ASYNC) + def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( 'eager_train_with_defun', MockIterator, defun=True) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index aa87b94e7b0876e65405f6bcb2d6aabde36582bf..492adbe1d80941f9df96d6636e4933d11239408e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -60,6 +60,7 @@ import functools import os import sys import time +import urllib import six import tensorflow as tf @@ -72,6 +73,8 @@ try: except ImportError: HAS_MATPLOTLIB = False +layers = tf.keras.layers + def parse(line): """Parse a line from the colors dataset.""" @@ -89,13 +92,35 @@ def parse(line): return rgb, chars, length +def maybe_download(filename, work_directory, source_url): + """Download the data from source url, unless it's already here. + + Args: + filename: string, name of the file in the directory. + work_directory: string, path to working directory. + source_url: url to download from if file doesn't exist. + + Returns: + Path to resulting file. + """ + if not tf.gfile.Exists(work_directory): + tf.gfile.MakeDirs(work_directory) + filepath = os.path.join(work_directory, filename) + if not tf.gfile.Exists(filepath): + temp_file_name, _ = urllib.request.urlretrieve(source_url) + tf.gfile.Copy(temp_file_name, filepath) + with tf.gfile.GFile(filepath) as f: + size = f.size() + print("Successfully downloaded", filename, size, "bytes.") + return filepath + + def load_dataset(data_dir, url, batch_size): """Loads the colors data at path into a PaddedDataset.""" # Downloads data at url into data_dir/basename(url). The dataset has a header # row (color_name, r, g, b) followed by comma-separated lines. - path = tf.contrib.learn.datasets.base.maybe_download( - os.path.basename(url), data_dir, url) + path = maybe_download(os.path.basename(url), data_dir, url) # This chain of commands loads our data by: # 1. skipping the header; (.skip(1)) @@ -109,7 +134,7 @@ def load_dataset(data_dir, url, batch_size): # pylint: disable=not-callable -class RNNColorbot(tfe.Network): +class RNNColorbot(tf.keras.Model): """Multi-layer (LSTM) RNN that regresses on real-valued vector labels. """ @@ -127,23 +152,20 @@ class RNNColorbot(tfe.Network): self.label_dimension = label_dimension self.keep_prob = keep_prob - # Note the calls to `track_layer` below; these calls register the layers as - # network components that house trainable variables. - self.cells = [ - self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size)) - for size in rnn_cell_sizes - ] - self.relu = self.track_layer( - tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu")) + self.cells = self._add_cells( + [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) + self.relu = layers.Dense( + label_dimension, activation=tf.nn.relu, name="relu") - def call(self, chars, sequence_length, training=False): + def call(self, inputs, training=False): """Implements the RNN logic and prediction generation. Args: - chars: a Tensor of dimension [batch_size, time_steps, 256] holding a - batch of one-hot encoded color names - sequence_length: a Tensor of dimension [batch_size] holding the length - of each character sequence (i.e., color name) + inputs: A tuple (chars, sequence_length), where chars is a batch of + one-hot encoded color names represented as a Tensor with dimensions + [batch_size, time_steps, 256] and sequence_length holds the length + of each character sequence (color name) as a Tensor with dimension + [batch_size]. training: whether the invocation is happening during training Returns: @@ -151,6 +173,7 @@ class RNNColorbot(tfe.Network): passing chars through a multi-layer RNN and applying a ReLU to the final hidden state. """ + (chars, sequence_length) = inputs # Transpose the first and second dimensions so that chars is of shape # [time_steps, batch_size, dimension]. chars = tf.transpose(chars, [1, 0, 2]) @@ -181,6 +204,14 @@ class RNNColorbot(tfe.Network): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) + def _add_cells(self, cells): + # "Magic" required for keras.Model classes to track all the variables in + # a list of layers.Layer objects. + # TODO(ashankar): Figure out API so user code doesn't have to do this. + for i, c in enumerate(cells): + setattr(self, "cell-%d" % i, c) + return cells + def loss(labels, predictions): """Computes mean squared loss.""" @@ -191,7 +222,7 @@ def test(model, eval_data): """Computes the average loss on eval_data, which should be a Dataset.""" avg_loss = tfe.metrics.Mean("loss") for (labels, chars, sequence_length) in tfe.Iterator(eval_data): - predictions = model(chars, sequence_length, training=False) + predictions = model((chars, sequence_length), training=False) avg_loss(loss(labels, predictions)) print("eval/loss: %.6f\n" % avg_loss.result()) with tf.contrib.summary.always_record_summaries(): @@ -204,7 +235,7 @@ def train_one_epoch(model, optimizer, train_data, log_interval=10): tf.train.get_or_create_global_step() def model_loss(labels, chars, sequence_length): - predictions = model(chars, sequence_length, training=True) + predictions = model((chars, sequence_length), training=True) loss_value = loss(labels, predictions) tf.contrib.summary.scalar("loss", loss_value) return loss_value @@ -277,7 +308,7 @@ def main(_): (chars, length) = (tf.identity(chars), tf.identity(length)) chars = tf.expand_dims(chars, 0) length = tf.expand_dims(length, 0) - preds = tf.unstack(model(chars, length, training=False)[0]) + preds = tf.unstack(model((chars, length), training=False)[0]) # Predictions cannot be negative, as they are generated by a ReLU layer; # they may, however, be greater than 1. diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 5c5c59c87744f4ffa6db90e5d8d3aa3bc8132756..be5d60449d7e08c99cc28e76befce56f468c77fd 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -38,22 +38,26 @@ import tensorflow as tf from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.eager.python import tfe +layers = tf.keras.layers -class RNN(tfe.Network): + +class RNN(tf.keras.Model): """A static RNN. - Similar to tf.nn.static_rnn, implemented as a tf.layer.Layer. + Similar to tf.nn.static_rnn, implemented as a class. """ def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - for _ in range(num_layers): - self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)) + self.cells = self._add_cells([ + tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) + for _ in range(num_layers) + ]) def call(self, input_seq, training): batch_size = int(input_seq.shape[1]) - for c in self.layers: + for c in self.cells: state = c.zero_state(batch_size, tf.float32) outputs = [] input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0) @@ -64,10 +68,22 @@ class RNN(tfe.Network): input_seq = tf.stack(outputs, axis=0) if training: input_seq = tf.nn.dropout(input_seq, self.keep_ratio) - return input_seq, None - - -class Embedding(tf.layers.Layer): + # Returning a list instead of a single tensor so that the line: + # y = self.rnn(y, ...)[0] + # in PTBModel.call works for both this RNN and CudnnLSTM (which returns a + # tuple (output, output_states). + return [input_seq] + + def _add_cells(self, cells): + # "Magic" required for keras.Model classes to track all the variables in + # a list of Layer objects. + # TODO(ashankar): Figure out API so user code doesn't have to do this. + for i, c in enumerate(cells): + setattr(self, "cell-%d" % i, c) + return cells + + +class Embedding(layers.Layer): """An Embedding layer.""" def __init__(self, vocab_size, embedding_dim, **kwargs): @@ -87,7 +103,8 @@ class Embedding(tf.layers.Layer): return tf.nn.embedding_lookup(self.embedding, x) -class PTBModel(tfe.Network): +# pylint: disable=not-callable +class PTBModel(tf.keras.Model): """LSTM for word language modeling. Model described in: @@ -109,19 +126,16 @@ class PTBModel(tfe.Network): self.keep_ratio = 1 - dropout_ratio self.use_cudnn_rnn = use_cudnn_rnn - self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim)) + self.embedding = Embedding(vocab_size, embedding_dim) if self.use_cudnn_rnn: self.rnn = cudnn_rnn.CudnnLSTM( num_layers, hidden_dim, dropout=dropout_ratio) else: self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio) - self.track_layer(self.rnn) - self.linear = self.track_layer( - tf.layers.Dense( - vocab_size, - kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))) + self.linear = layers.Dense( + vocab_size, kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1)) self._output_shape = [-1, embedding_dim] def call(self, input_seq, training): @@ -136,7 +150,7 @@ class PTBModel(tfe.Network): y = self.embedding(input_seq) if training: y = tf.nn.dropout(y, self.keep_ratio) - y, _ = self.rnn(y, training=training) + y = self.rnn(y, training=training)[0] return self.linear(tf.reshape(y, self._output_shape)) @@ -148,7 +162,7 @@ def clip_gradients(grads_and_vars, clip_ratio): def loss_fn(model, inputs, targets, training): labels = tf.reshape(targets, [-1]) - outputs = model(inputs, training) + outputs = model(inputs, training=training) return tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits( labels=labels, logits=outputs)) @@ -301,32 +315,37 @@ def main(_): have_gpu = tfe.num_gpus() > 0 use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(FLAGS.logdir)): - with tf.device("/device:GPU:0" if have_gpu else None): - # Make learning_rate a Variable so it can be included in the checkpoint - # and we can resume training with the last saved learning_rate. - learning_rate = tfe.Variable(20.0, name="learning_rate") - sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) - model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, - FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, - use_cudnn_rnn) - optimizer = tf.train.GradientDescentOptimizer(learning_rate) - - best_loss = None - for _ in range(FLAGS.epoch): - train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) - eval_loss = evaluate(model, eval_data) - if not best_loss or eval_loss < best_loss: - if FLAGS.logdir: - tfe.Saver(model.trainable_weights + [learning_rate]).save( - os.path.join(FLAGS.logdir, "ckpt")) - best_loss = eval_loss - else: - learning_rate.assign(learning_rate / 4.0) - sys.stderr.write("eval_loss did not reduce in this epoch, " - "changing learning rate to %f for the next epoch\n" % - learning_rate.numpy()) + with tf.device("/device:GPU:0" if have_gpu else None): + # Make learning_rate a Variable so it can be included in the checkpoint + # and we can resume training with the last saved learning_rate. + learning_rate = tfe.Variable(20.0, name="learning_rate") + model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, + FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, + use_cudnn_rnn) + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + checkpoint = tfe.Checkpoint( + learning_rate=learning_rate, model=model, + # GradientDescentOptimizer has no state to checkpoint, but noting it + # here lets us swap in an optimizer that does. + optimizer=optimizer) + # Restore existing variables now (learning_rate), and restore new variables + # on creation if a checkpoint exists. + checkpoint.restore(tf.train.latest_checkpoint(FLAGS.logdir)) + sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) + + best_loss = None + for _ in range(FLAGS.epoch): + train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) + eval_loss = evaluate(model, eval_data) + if not best_loss or eval_loss < best_loss: + if FLAGS.logdir: + checkpoint.save(os.path.join(FLAGS.logdir, "ckpt")) + best_loss = eval_loss + else: + learning_rate.assign(learning_rate / 4.0) + sys.stderr.write("eval_loss did not reduce in this epoch, " + "changing learning rate to %f for the next epoch\n" % + learning_rate.numpy()) if __name__ == "__main__": diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index a1f8a759e2a556bc219f0aa13942f293c4f34cfa..5966f1d4873e8e77b3ad5914da7bfc7e69d4e341 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -38,5 +38,9 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], - tags = ["no_pip"], # because spinn.py is under third_party/. + tags = [ + "no-internal-py3", # flaky + "no_cuda_on_cpu_tap", + "no_pip", # because spinn.py is under third_party/. + ], ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index eefc06d90d83b61d07a613643c913d3833a5f2c1..f825a2a7363fbe144162eca96398920ead0c4e50 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -34,6 +34,7 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util +from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.training import checkpoint_utils @@ -369,7 +370,7 @@ class SpinnTest(test_util.TensorFlowTestCase): inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )")) logits = spinn.train_or_infer_spinn( embed, word2index, None, None, None, config) - self.assertEqual(np.float32, logits.dtype) + self.assertEqual(tf.float32, logits.dtype) self.assertEqual((3,), logits.shape) def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self): @@ -417,12 +418,17 @@ class SpinnTest(test_util.TensorFlowTestCase): if event.summary.value and event.summary.value[0].tag == "train/loss"] self.assertEqual(config.epochs, len(train_losses)) - self.assertLess(train_losses[-1], train_losses[0]) # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - ckpt_variable_names = [ - item[0] for item in checkpoint_utils.list_variables(config.logdir)] + object_graph_string = checkpoint_utils.load_variable( + config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") + object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() + object_graph.ParseFromString(object_graph_string) + ckpt_variable_names = set() + for node in object_graph.nodes: + for attribute in node.attributes: + ckpt_variable_names.add(attribute.full_name) self.assertIn("global_step", ckpt_variable_names) for v in trainer.variables: variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index ffc1d0332eae605ce0444a225e53baa68954cae0..2d2aba6908b168e0bf63f4706b6344cbb4ca82bd 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -1,903 +1,18 @@ -# TensorFlow Eager Execution - -## What is this? +# Eager execution Eager execution is a feature that makes TensorFlow execute operations -immediately: concrete values are returned, instead of a computational graph to -be executed later. - -As a result, enabling eager execution provides: - -- A [NumPy](http://www.numpy.org/)-like library for numerical computation with - support for GPU acceleration and automatic differentiation. -- A flexible platform for machine learning research and experimentation. - -Eager execution is under active development. This guide walks through an -alpha/preview release. In particular, not all TensorFlow APIs currently work -with eager execution enabled, and some models may be slow to execute, compared -to models defined without using eager execution. - -## Installation - -Eager execution is included in TensorFlow versions 1.5 and above. -Installation instructions at https://www.tensorflow.org/install/ - -The contents of this guide are compatible with TensorFlow 1.5. -However, if you run into bugs that are fixed in source but not the -release, you may want to either either [building from -source](https://www.tensorflow.org/install/install_sources) -or the try latest nightly builds. The nightly builds are available as: - -- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and - -- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. - -For example, to run the latest nightly docker image: - -```sh -# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker -docker pull tensorflow/tensorflow:nightly-gpu -docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu - -# If you do not have a GPU, use the CPU-only image -docker pull tensorflow/tensorflow:nightly -docker run -it -p 8888:8888 tensorflow/tensorflow:nightly -``` - -And then visit http://localhost:8888 in your browser for a Jupyter notebook -environment. - -## Getting Started - -With TensorFlow installed, eager execution is enabled via a single call: - -```python -import tensorflow as tf - -import tensorflow.contrib.eager as tfe - -tfe.enable_eager_execution() -``` - -Enabling eager execution changes how TensorFlow functions behave (in particular, -`Tensor` objects will reference concrete values instead of being symbolic -handles to nodes in a computational graph). As a result, eager execution should -be enabled at the beginning of a program and cannot be disabled afterwards in -the same program. - -Code examples in the rest of this guide assume that eager execution has been -enabled. - -## A library for numerical computation - -A significant fraction of the [TensorFlow -API](https://www.tensorflow.org/api_docs/python/) consists of numerical -operations: -[arithmetic operations](https://www.tensorflow.org/api_guides/python/math_ops#Arithmetic_Operators), -[matrix operations](https://www.tensorflow.org/api_guides/python/math_ops#Matrix_Math_Functions), -[linear algebra operations](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), -etc. - -With eager execution enabled, these operations consume and return -multi-dimensional arrays as `Tensor` objects, similar to NumPy -[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html). -For example: - -```python -# Multiply two 2x2 matrices -x = tf.matmul([[1, 2], - [3, 4]], - [[4, 5], - [6, 7]]) -# Add one to each element -# (tf.add supports broadcasting) -y = tf.add(x, 1) - -# Create a random random 5x3 matrix -z = tf.random_uniform([5, 3]) - -print(x) -print(y) -print(z) -``` - -Output: - -``` -tf.Tensor( -[[16 19] - [36 43]], shape=(2, 2), dtype=int32) -tf.Tensor( -[[17 20] - [37 44]], shape=(2, 2), dtype=int32) -tf.Tensor( -[[ 0.25058532 0.0929395 0.54113817] - [ 0.3108716 0.93350542 0.84909797] - [ 0.53081679 0.12788558 0.01767385] - [ 0.29725885 0.33540785 0.83588314] - [ 0.38877153 0.39720535 0.78914213]], shape=(5, 3), dtype=float32) -``` - -For convenience, these operations can also be triggered via operator overloading -of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`, -`-` to `tf.subtract`, `*` to `tf.multiply`, etc.: - -```python -x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1 -print(x) -``` - -Output: - -``` -tf.Tensor([ 3.], shape=(1,), dtype=float32) -``` - -### Converting to and from NumPy - -The operations above automatically convert Python objects (like lists of -numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be used -as NumPy arrays by numpy operations. - -```python -import numpy as np - -x = tf.add(1, 1) # tf.Tensor with a value of 2 -y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2 -z = np.multiply(x, y) # numpy.int64 with a value of 4 -``` - -Alternatively, they can be explicitly converted using -[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as -shown in the next example. - -Conversely, you can call the `numpy()` method of a `Tensor` object' to obtain -its NumPy `ndarray` value. For example: - -```python -import numpy as np - -np_x = np.array(2., dtype=np.float32) -x = tf.constant(np_x) - -py_y = 3. -y = tf.constant(py_y) - -z = x + y + 1 - -print(z) -print(z.numpy()) -``` - -Output: - -``` -tf.Tensor(6.0, shape=(), dtype=float32) -6.0 -``` - -### GPU acceleration - -Many TensorFlow operations support GPU acceleration. With eager execution -enabled, [computation is *not* automatically -offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you -must explicitly specify when GPUs should be used. - -The simplest way to do this is to enclose your computation in a `with -tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function, -which returns the number of available GPUs. - -For example, consider this snippet to measure the time to multiply two 1000x1000 -matrices on CPU: - -```python -import time - -def measure(x): - # The very first time a GPU is used by TensorFlow, it is initialized. - # So exclude the first run from timing. - tf.matmul(x, x) - - start = time.time() - for i in range(10): - tf.matmul(x, x) - end = time.time() - - return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape) - -# Run on CPU: -with tf.device("/cpu:0"): - print("CPU: %s" % measure(tf.random_normal([1000, 1000]))) - -# If a GPU is available, run on GPU: -if tfe.num_gpus() > 0: - with tf.device("/gpu:0"): - print("GPU: %s" % measure(tf.random_normal([1000, 1000]))) -``` - -Output (exact numbers will depend on the characteristics of the hardware): - -```python -CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times -GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times -``` - -Alternatively, methods on the `Tensor` object can be used to explicitly copy the -`Tensor` to a different device. Operations are typically executed on the device -on which the inputs are placed. For example: - -```python -x = tf.random_normal([10, 10]) - -x_gpu0 = x.gpu() -x_cpu = x.cpu() - -_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU -_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0 - -if tfe.num_gpus() > 1: - x_gpu1 = x.gpu(1) - _ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1 -``` - -### Automatic Differentiation - -[Automatic -differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is -very useful when implementing many machine learning algorithms (e.g., -[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training -neural networks). For this purpose, TensorFlow eager execution provides an -[autograd](https://github.com/HIPS/autograd)-style API for automatic -differentiation. Specifically, the functions: - -- `tfe.gradients_function(f)`: Returns a Python function that computes the - derivatives of the Python function `f` with respect to its arguments. `f` - must return a scalar value. When the returned function is invoked, it - returns a list of `Tensor` objects (one element for each argument of `f`). -- `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`, - except that when the returned function is invoked, it returns the value of - `f` in addition to the list of derivatives of `f` with respect to its - arguments. - -These functions naturally apply to higher order differentiation as well. For -example: - -```python -def f(x): - return tf.multiply(x, x) # Or x * x -assert 9 == f(3.).numpy() - -df = tfe.gradients_function(f) -assert 6 == df(3.)[0].numpy() - -# Second order deriviative. -d2f = tfe.gradients_function(lambda x: df(x)[0]) -assert 2 == d2f(3.)[0].numpy() - -# Third order derivative. -d3f = tfe.gradients_function(lambda x : d2f(x)[0]) -assert 0 == d3f(3.)[0].numpy() -``` - -These functions can be used to train models. For example, consider the following -simple linear regression model: - -```python -def prediction(input, weight, bias): - return input * weight + bias - -# A toy dataset of points around 3 * x + 2 -NUM_EXAMPLES = 1000 -training_inputs = tf.random_normal([NUM_EXAMPLES]) -noise = tf.random_normal([NUM_EXAMPLES]) -training_outputs = training_inputs * 3 + 2 + noise - -# A loss function: Mean-squared error -def loss(weight, bias): - error = prediction(training_inputs, weight, bias) - training_outputs - return tf.reduce_mean(tf.square(error)) - -# Function that returns the derivative of loss with respect to -# weight and bias -grad = tfe.gradients_function(loss) - -# Train for 200 steps (starting from some random choice for W and B, on the same -# batch of data). -W = 5. -B = 10. -learning_rate = 0.01 -print("Initial loss: %f" % loss(W, B).numpy()) -for i in range(200): - (dW, dB) = grad(W, B) - W -= dW * learning_rate - B -= dB * learning_rate - if i % 20 == 0: - print("Loss at step %d: %f" % (i, loss(W, B).numpy())) -print("Final loss: %f" % loss(W, B).numpy()) -print("W, B = %f, %f" % (W.numpy(), B.numpy())) -``` - -Output: (the exact numbers may vary depending on the randomness in noise) - -``` -Initial loss: 66.730003 -Loss at step 0: 64.200096 -Loss at step 20: 29.872814 -Loss at step 40: 14.233772 -Loss at step 60: 7.090570 -Loss at step 80: 3.819887 -Loss at step 100: 2.318821 -Loss at step 120: 1.628385 -Loss at step 140: 1.310142 -Loss at step 160: 1.163167 -Loss at step 180: 1.095162 -Final loss: 1.064711 -W, B = 3.094944, 2.161383 -``` - -To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):` -block. (However, this particular model, with only two floating point parameters, -is unlikely to benefit from GPU acceleration.) - -### Customizing gradients - -One may want to define custom gradients for an operation, or for a function. -This may be useful for multiple reasons, including providing a more efficient -or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability) -gradient for a sequence of operations. - -For example, consider the function `log(1 + e^x)`, which commonly occurs in the -computation of cross entropy and log likelihoods. - -```python -def log1pexp(x): -  return tf.log(1 + tf.exp(x)) -grad_log1pexp = tfe.gradients_function(log1pexp) - -# Works fine at x = 0. -assert 0.5 == float(grad_log1pexp(0.)[0]) - -# Returns a `nan` at x = 100 due to numerical instability. -import math -assert math.isnan(float(grad_log1pexp(100.)[0])) -``` - -We can define a custom gradient for the above function that analytically -simplifies the gradient expression. - -```python -@tfe.custom_gradient -def log1pexp(x): -  e = tf.exp(x) -  def grad(dy): -    return dy * (1 - 1 / (1 + e)) -  return tf.log(1 + e), grad -grad_log1pexp = tfe.gradients_function(log1pexp) - -# Works as before at x = 0. -assert 0.5 == float(grad_log1pexp(0.)[0]) - -# But now works at x = 100 as well. -assert 1.0 == float(grad_log1pexp(100.)[0]) -``` -Also notice how the gradient function implementation reuses an expression -(`tf.exp(x)`) computed during the forward pass, hence making the gradient -computation more efficient by avoiding redundant computation. - -## Building and training models - -In practice, your computation may have many parameters to be optimized (by -computing derivatives). Encapsulating them into re-usable classes/objects -makes the code easier to follow than writing a single top-level function with -many arguments. - -In fact, eager execution encourages use of the [Keras](https://keras.io)-style -"Layer" classes in the -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) -module. - -Furthermore, you may want to apply more sophisticated techniques to compute -parameter updates, such as those in -[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers) -implementations. - -This next section walks through using the same `Optimizer` and `Layer` APIs used -to build trainable TensorFlow graphs in an environment where eager execution is -enabled. - -### Variables and Optimizers - -`tfe.Variable` objects store mutable `Tensor` values that can be accessed during -training, making automatic differentiation easier. In particular, parameters of -a model can be encapsulated in Python classes as variables. - -`tfe.gradients_function(f)` introduced earlier computes the derivatives of `f` -with respect to its arguments. However, it requires all parameters of interest -to be arguments of `f`, which becomes cumbersome when `f` depends on a large -number of trainable parameters. - -`tfe.implicit_gradients` is an alternative function with some useful properties: - -- It computes the derivatives of `f` with respect to all the `tfe.Variable`s - used by `f`. -- When the returned function is invoked, it returns a list of - (gradient value, Variable object) tuples. - -Representing model parameters as `Variable` objects, along with the use of -`tfe.implicit_gradients`, typically results in better encapsulation. For -example, the linear regression model described above can be written into a -class: - -```python -class Model(object): - def __init__(self): - self.W = tfe.Variable(5., name='weight') - self.B = tfe.Variable(10., name='bias') - - def predict(self, inputs): - return inputs * self.W + self.B - - -# The loss function to be optimized -def loss(model, inputs, targets): - error = model.predict(inputs) - targets - return tf.reduce_mean(tf.square(error)) - -# A toy dataset of points around 3 * x + 2 -NUM_EXAMPLES = 1000 -training_inputs = tf.random_normal([NUM_EXAMPLES]) -noise = tf.random_normal([NUM_EXAMPLES]) -training_outputs = training_inputs * 3 + 2 + noise - -# Define: -# 1. A model -# 2. Derivatives of a loss function with respect to model parameters -# 3. A strategy for updating the variables based on the derivatives -model = Model() -grad = tfe.implicit_gradients(loss) -optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - -# The training loop -print("Initial loss: %f" % - loss(model, training_inputs, training_outputs).numpy()) -for i in range(201): - optimizer.apply_gradients(grad(model, training_inputs, training_outputs)) - if i % 20 == 0: - print("Loss at step %d: %f" % - (i, loss(model, training_inputs, training_outputs).numpy())) -print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy()) -print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy())) -``` - -Output: - -``` -Initial loss: 69.693184 -Loss at step 0: 66.987854 -Loss at step 20: 30.553387 -Loss at step 40: 14.250237 -Loss at step 60: 6.955020 -Loss at step 80: 3.690550 -Loss at step 100: 2.229739 -Loss at step 120: 1.576032 -Loss at step 140: 1.283496 -Loss at step 160: 1.152584 -Loss at step 180: 1.093999 -Final loss: 1.067780 -W, B = 3.0114281, 2.0865183 -``` - -Using `implicit_gradients` avoids the need to provide all the trainable -parameters of the model as arguments to the `loss` function. - -### Using Keras and the Layers API - -[Keras](https://keras.io) is a popular API for defining model structures. The -[`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) -module provides a set of building blocks for models and is implemented using the -`tf.layers.Layer` subclasses in the -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) -module. We encourage the use of these same building blocks when using -TensorFlow's eager execution feature. For example, the very same linear -regression model can be built using `tf.layers.Dense`: - -```python -class Model(object): - def __init__(self): - self.layer = tf.layers.Dense(1) - - def predict(self, inputs): - return self.layer(inputs) -``` - -The `tf.layers` API makes it more convenient to define more sophisticated -models. For example, the following will train an MNIST model: - -```python -class MNISTModel(object): - def __init__(self, data_format): - # 'channels_first' is typically faster on GPUs - # while 'channels_last' is typically faster on CPUs. - # See: https://www.tensorflow.org/performance/performance_guide#data_formats - if data_format == 'channels_first': - self._input_shape = [-1, 1, 28, 28] - else: - self._input_shape = [-1, 28, 28, 1] - self.conv1 = tf.layers.Conv2D(32, 5, - padding='same', - activation=tf.nn.relu, - data_format=data_format) - self.max_pool2d = tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='same', data_format=data_format) - self.conv2 = tf.layers.Conv2D(64, 5, - padding='same', - activation=tf.nn.relu, - data_format=data_format) - self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu) - self.dropout = tf.layers.Dropout(0.5) - self.dense2 = tf.layers.Dense(10) - - def predict(self, inputs): - x = tf.reshape(inputs, self._input_shape) - x = self.max_pool2d(self.conv1(x)) - x = self.max_pool2d(self.conv2(x)) - x = tf.layers.flatten(x) - x = self.dropout(self.dense1(x)) - return self.dense2(x) - -def loss(model, inputs, targets): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=model.predict(inputs), labels=targets)) - - -# Load the training and validation data -from tensorflow.examples.tutorials.mnist import input_data -data = input_data.read_data_sets("./mnist_data", one_hot=True) - -# Train -device = "gpu:0" if tfe.num_gpus() else "cpu:0" -model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last') -optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) -grad = tfe.implicit_gradients(loss) -for i in range(20001): - with tf.device(device): - (inputs, targets) = data.train.next_batch(50) - optimizer.apply_gradients(grad(model, inputs, targets)) - if i % 100 == 0: - print("Step %d: Loss on training set : %f" % - (i, loss(model, inputs, targets).numpy())) -print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy()) -``` - -For a more complete example, see -[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) - -### Checkpointing trained variables - -TensorFlow Variables (`tfe.Variable`) provides a way to represent shared, -persistent state of your model. The `tfe.Saver` class (which is a thin wrapper -over the -[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver) -class) provides a means to save and restore variables to and from _checkpoints_. - -For example: - -```python -# Create variables. -x = tfe.Variable(10., name='x') -y = tfe.Variable(5., name='y') - -# Create a Saver. -saver = tfe.Saver([x, y]) - -# Assign new values to the variables and save. -x.assign(2.) -saver.save('/tmp/ckpt') - -# Change the variable after saving. -x.assign(11.) -assert 16. == (x + y).numpy() # 11 + 5 - -# Restore the values in the checkpoint. -saver.restore('/tmp/ckpt') - -assert 7. == (x + y).numpy() # 2 + 5 -``` - -### `tfe.Network` - -You may often want to organize your models using classes, like the `MNISTModel` -class described above. We recommend inheriting from the `tfe.Network` class as -it provides conveniences like keeping track of all model variables and methods -to save and restore from checkpoints. - -Sub-classes of `tfe.Network` may register `Layer`s (like classes in -[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), -or [Keras -layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) -using a call to `self.track_layer()` and define the computation in an -implementation of `call()`. - -Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables -lazily, when the first input is encountered. - -For example, consider the following two-layer neural network: - -```python -class TwoLayerNet(tfe.Network): - def __init__(self): - super(TwoLayerNet, self).__init__() - self.layer1 = self.track_layer( - tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False)) - self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False)) - - def call(self, x): - return self.layer2(self.layer1(x)) - -net = TwoLayerNet() - -# No variables created yet -assert 0 == len(net.variables) - -# They are created on first input: -inp = tf.constant([[1.]]) - -# Since input is a 1x1 matrix, net.l1 has 2 units and net.l2 has 3 units, -# the output is the product of a 1x1 matrix with a 1x2 matrix with a 2x3 -# matrix. -assert [1, 3] == net(inp).shape.as_list() # Invoke net; get output shape. -assert 1 == len(net.layer1.variables) -assert 1 == len(net.layer2.variables) -assert 2 == len(net.variables) # weights for each layer. -assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1. -assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2. -``` - -The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows -instances of `tfe.Network` to be embedded in other networks. For example: - -```python -class ThreeLayerNet(tfe.Network): - def __init__(self): - super(ThreeLayerNet, self).__init__() - self.a = self.track_layer(TwoLayerNet()) - self.b = self.track_layer(tf.layers.Dense(4, use_bias=False)) - - def call(self, x): - return self.b(self.a(x)) - -net = ThreeLayerNet() - -assert [1, 4] == net(inp).shape.as_list() -assert 3 == len(net.variables) -assert [1, 2] == net.variables[0].shape.as_list() -assert [2, 3] == net.variables[1].shape.as_list() -assert [3, 4] == net.variables[2].shape.as_list() -``` - -See more examples in -[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples). - -`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a -convenient way to save and load checkpoints without changing the program once -the checkpoint has been created. For example, we can set an objective for the -output of our network, choose an optimizer, and a location for the checkpoint: - -```python -objective = tf.constant([[2., 3., 4., 5.]]) -optimizer = tf.train.AdamOptimizer(0.01) -checkpoint_directory = '/tmp/tfe_example' -checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') -net = ThreeLayerNet() -``` - -Note that variables have not been created yet. We want them to be restored from -a checkpoint, if one exists, so we create them inside a -`tfe.restore_variables_on_create` context manager. Then our training loop is the -same whether starting training or resuming from a previous checkpoint: - -```python -with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(checkpoint_directory)): - global_step = tf.train.get_or_create_global_step() - for _ in range(100): - loss_fn = lambda: tf.norm(net(inp) - objective) - optimizer.minimize(loss_fn, global_step=global_step) - if tf.equal(global_step % 20, 0): - print("Step %d, output %s" % (global_step.numpy(), - net(inp).numpy())) - all_variables = ( - net.variables - + optimizer.variables() - + [global_step]) - # Save the checkpoint. - tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) -``` - -The first time it runs, `Network` variables are initialized randomly. Then the -output is trained to match the objective we've set: - -``` -Step 20, output [[ 0.03575622 0.29863232 0.03474367 0.24735749]] -Step 40, output [[ 0.40646029 0.9856872 0.46851286 0.95358551]] -Step 60, output [[ 1.74541104 2.800704 1.79055595 2.74783421]] -Step 80, output [[ 2.14977384 3.44340849 3.96120024 5.16242075]] -Step 100, output [[ 1.99943113 3.02364397 3.93500996 4.9610076 ]] -``` - -In subsequent iterations, variables are initialized with the values read from -the latest checkpoint. Running the same code again, we continue from where we -left off: - -``` -Step 120, output [[ 1.99234128 3.0271616 3.98732996 4.96401167]] -Step 140, output [[ 2.00133467 3.01270437 4.00616646 5.00406504]] -Step 160, output [[ 1.99647415 2.9956708 3.99064088 4.99632359]] -Step 180, output [[ 2.00699997 3.00904822 4.00706148 5.01193142]] -Step 200, output [[ 1.98334622 2.98249531 3.97375059 4.97123432]] -``` - - -### Summaries, metrics and TensorBoard - -[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) -is a popular tool for understanding, debugging and optimizing the model training -process. To benefit from the visualizations offered by TensorBoard, summary -events need to be written during the course of execution of your program. You -might find many Tensorflow programs that include the -[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations -during graph construction. - -`tf.summary` operations are *not* compatible with eager execution, but an -equivalent alternative exists in -[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/summary) -that is compatible with both eager execution and graph construction. - -During model construction simply insert summary operations like -`tf.contrib.summary.scalar`. These operations do nothing by default, unless a -summary writer is currently active and a writing policy is set. - -For example, to record summaries once every 100 global steps, use: - -```python -tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_file_writer(logdir) - -for _ in range(iterations): - with writer.as_default(): - with tf.contrib.summary.record_summaries_every_n_global_steps(100): - # your model code goes here - tf.contrib.summary.scalar('loss', loss) - # ... -``` - -See the full mnist example in -[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) -for a full model using `tf.contrib.summary`. - -Similarly to summaries, the metrics in `tf.metrics` are currently not compatible -with eager execution. We instead provide object-oriented metrics in the -`tfe.metrics` package, which are compatible with graph construction as well. - -Metrics in the `tfe.metrics`, such as `tfe.metrics.Mean` and -`tfe.Metrics.Accuracy`, all implement an intuitive object-oriented -interface. Here's an example of how to use the `tfe.metrics.Mean` metric: - -```python -# Metrics are objects, which can be created and destroyed. -my_mean = tfe.metrics.Mean(name='my_mean') -# While a metric is active, you can call it as a function to accumulate into its -# internal state. -my_mean(0.0) -my_mean(10.0) -# Once you've finished updating the metric, you can get its result. In this case -# a simple average over all the calls to it. If a summary writer is active the -# metric will write the appropriate summaries using the metric name. -assert 5.0 == my_mean.result().numpy() -``` - -For a full example of a model using metrics for evaluation, see the mnist -example in -[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist). - -### Input Pipelines - -The discussion above has been centered around the computation executed by your -model. The -[`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) -module provides APIs to build complex input pipelines from simple, reusable -pieces. - -If you're familiar with constructing `tf.data.Dataset` objects when building -TensorFlow graphs, the same API calls are used when eager execution is enabled. -However, the process of iterating over elements of the dataset differs between -eager execution and graph construction. When eager execution is enabled, the -discussion on iterator creation using `make_one_shot_iterator()` and -`get_next()` in the -[Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is -*not* applicable. Instead, a more Pythonic `Iterator` class is available. - -For example: - -```python -# Create a source Dataset from in-memory numpy arrays. -# For reading from files on disk, you may want to use other Dataset classes -# like the TextLineDataset or the TFRecordDataset. -dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]) - -# Apply transformations, shuffling, batching etc. -dataset = dataset.map(tf.square).shuffle(2).batch(2) - -# Use tfe.Iterator to iterate over the dataset. -for x in tfe.Iterator(dataset): - print(x) -``` - -Output: - -``` -tf.Tensor([4 9], shape=(2,), dtype=int32) -tf.Tensor([16 25], shape=(2,), dtype=int32) -tf.Tensor([36 1], shape=(2,), dtype=int32) -``` - -## Interoperating with Graphs - -Eager execution improves the process of model development in Python; however, -because it is in its earliest stages, it does not yet support some features -available to [TensorFlow -graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph) -that are desirable when deploying models in production. In particular, eager -execution does not yet support distributed training, exporting models (to other -[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow -serving](https://www.tensorflow.org/serving/), and mobile applications), and -various memory and computation optimizations that are applied to TensorFlow's -dataflow graphs. - -That said, the APIs used to build modes are exactly the same whether executing -eagerly or constructing graphs. This means that you can iteratively develop your -model with eager execution enabled and later, if needed, use the same code to -reap the benefits of representing models as computational graphs. - -For example, -[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) -defines a model that is eagerly executed. That same code is used to construct -and execute a graph in -[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py). - -Other models in the [examples -directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/) -demonstrate this as well. - -Some differences worth noting: - -- There is no notion of a `tf.placeholder` or a `tf.Session` when eager - execution is enabled. -- Many properties on the `tf.Tensor` object, like `tf.Tensor.name`, - `tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution - is enabled and their use will raise an `AttributeError`. -- To use `tfe.implicit_gradients` in graph construction, variables must be - created with [`use_resource=True`] provided to - [`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable) - or - [`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope). -- Some API calls (such as the functional-style `tf.layers.dense`, - `tf.layers.conv2d`) are not compatible with eager execution. Use of such - methods should raise an error indicating the alternative (e.g., the - `tf.layers.Dense` and `tf.layers.Conv2D` classes). - -## What next? +immediately: concrete values are returned, instead of creating a computational +graph that is executed later. -Please give eager execution a spin. This feature is in early stages and is -evolving, so we welcome your feedback via issues on GitHub (see [known -issues](https://github.com/tensorflow/tensorflow/labels/comp:eager)). +A user guide is available: https://www.tensorflow.org/programmers_guide/eager +([source file](../../../../docs_src/programmers_guide/eager.md)) -You may want to browse through some sample code, including benchmarks for some: +We welcome feedback through [GitHub issues](https://github.com/tensorflow/tensorflow/labels/comp:eager). -- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) -- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) -- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) -- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) -- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) +Sample code is available, including benchmarks for some: +- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) +- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) +- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) +- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) +- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index ea8dbf2b46ea4bd0e33645ae3c590c4dd13f7a52..907f9204c2d31a652ca2a0539a23db4722b4e154 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -20,7 +20,6 @@ from __future__ import print_function import re -from tensorflow.contrib.summary import summary_ops from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes @@ -29,13 +28,14 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope - +from tensorflow.python.training import checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") -class Metric(object): +class Metric(checkpointable.CheckpointableBase): """A metric holds state for aggregating statistics over an evaluation run. Example use with eager execution: @@ -93,11 +93,12 @@ class Metric(object): `aggregate()`, it is for use by TensorFlow infrastructure. """ - def __init__(self, name=None): + def __init__(self, name=None, use_global_variables=False): self._built = False self._vars = [] self._initial_values = {} self._updates = [] + self._use_global_variables = use_global_variables name = name or self.__class__.__name__ # Replace things like spaces in name to create a valid scope name. scope_name = _to_replace.sub("_", name) @@ -108,13 +109,25 @@ class Metric(object): pos = scope.name.rfind(scope_name) self._name = name + scope.name[pos + len(scope_name):] self._scope = scope - if context.in_graph_mode(): + + # Ensures that if the user calls build directly we still set self._built to + # True to prevent variables from being recreated. + self._build = self.build + + def actual_build(*args, **kwargs): + self._build(*args, **kwargs) + self._built = True + self.build = actual_build + self.build.__doc__ = self._build.__doc__ + + # Captures construction scope for proper initialization. + if context.executing_eagerly(): + self._construction_scope = context.eager_mode + else: # We make self.call() into a graph callable here, so that we can # return a single op that performs all of the variable updates. self._construction_scope = ops.get_default_graph().as_default self.call = function.defun(self.call) - else: - self._construction_scope = context.eager_mode # ---- API for users ---- def __call__(self, *args, **kwargs): @@ -155,10 +168,11 @@ class Metric(object): initialization. Under eager execution, the variables are reset to their initial values as a side effect and this function returns None. """ - if context.in_graph_mode(): + if context.executing_eagerly(): + for v in self._vars: + v.assign(self._initial_values[v]) + else: return control_flow_ops.group([v.initializer for v in self._vars]) - for v in self._vars: - v.assign(self._initial_values[v]) # ---- To be implemented by descendants --- def build(self, *args, **kwargs): @@ -200,10 +214,10 @@ class Metric(object): def value(self): """In graph mode returns the result Tensor while in eager the callable.""" - if context.in_graph_mode(): - return self.result() - else: + if context.executing_eagerly(): return self.result + else: + return self.result() # We can support two different strategies of for doing data-parallel # distributed metric computations: @@ -245,19 +259,31 @@ class Metric(object): """***Only for use by descendants of Metric***.""" if self._built: raise RuntimeError("Can't call add_variable() except in build().") - collections = None if context.in_eager_mode() else [ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ] - v = variable_scope.get_variable( - name, - shape, - dtype, - initializer, + if context.executing_eagerly(): + collections = None + else: + if self._use_global_variables: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + else: + collections = [ops.GraphKeys.LOCAL_VARIABLES] + collections += [ops.GraphKeys.METRIC_VARIABLES] + # Variables are Checkpointable dependencies of Metrics regardless of the + # global/local distinction. Users can avoid saving variables by not adding a + # dependency on the Metric. + v = self._add_variable_with_custom_getter( + name=name, + shape=shape, + dtype=dtype, + initializer=initializer, trainable=False, collections=collections, - use_resource=True) + use_resource=True, + getter=variable_scope.get_variable, + # Raise duplicate variable exceptions from get_variable rather than + # Checkpointable. + overwrite=True) self._vars.append(v) - if context.in_eager_mode(): + if context.executing_eagerly(): self._initial_values[v] = v.value() return v @@ -267,8 +293,10 @@ class Mean(Metric): # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? # Or defaults to type of the input if it is tf.float32, else tf.float64? - def __init__(self, name=None, dtype=dtypes.float64): - super(Mean, self).__init__(name=name) + def __init__(self, name=None, dtype=dtypes.float64, + use_global_variables=False): + super(Mean, self).__init__(name=name, + use_global_variables=use_global_variables) self.dtype = dtype def build(self, *args, **kwargs): diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index a9ecaa3f8bced3043ea0eb0ac3aa8bfa65e9e1ff..f0fe4ce8c53bb80c03a3f0de37078bcdb975a0b4 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile from tensorflow.contrib.eager.python import metrics -from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -29,6 +29,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import summary_ops_v2 as summary_ops +from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util @@ -50,6 +52,19 @@ class MetricsTest(test.TestCase): self.assertEqual( set(m.variables), set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))) + self.assertEqual(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES), []) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + + def testUseGlobalVariablesCollections(self): + with context.graph_mode(), ops.Graph().as_default(): + m = metrics.Mean(use_global_variables=True) + m(1000) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + self.assertEqual(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES), []) self.assertEqual( set(m.variables), set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) @@ -180,6 +195,15 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testBuildMean(self): + # Verify that calling build() on Mean and then calling it won't recreate + # variables. + m = metrics.Mean() + m.build() + old_numer = m.numer + m(0.0) + self.assertTrue(old_numer is m.numer) + def testMetricsChain(self): with context.graph_mode(), self.test_session(): m1 = metrics.Mean() @@ -193,6 +217,31 @@ class MetricsTest(test.TestCase): self.assertAllEqual(m2.result().eval(), 2.0) self.assertAllEqual(m1.result().eval(), 1.0) + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + mean = metrics.Mean() + checkpoint = checkpointable_utils.Checkpoint(mean=mean) + mean.build() + mean._built = True + self.evaluate(mean.init_variables()) + self.evaluate(mean(100.)) + self.evaluate(mean(200.)) + save_path = checkpoint.save(checkpoint_prefix) + self.evaluate(mean(1000.)) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.evaluate(mean(300.)) + self.assertAllEqual(200., self.evaluate(mean.value())) + + restore_mean = metrics.Mean() + restore_checkpoint = checkpointable_utils.Checkpoint(mean=restore_mean) + status = restore_checkpoint.restore(save_path) + restore_update = restore_mean(300.) + status.assert_consumed().run_restore_ops() + self.evaluate(restore_update) + self.assertAllEqual(200., self.evaluate(restore_mean.value())) + self.assertEqual(3, self.evaluate(restore_mean.denom)) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d..2f8721324f5fc12565d047a64af22b8df215a92b 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -25,6 +25,7 @@ import weakref from tensorflow.python.eager import context from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops +from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpoint_utils @@ -149,7 +150,7 @@ class Network(base.Layer): # check we might have name collisions if the parent scope on init gets # closed before build is called. self._variable_scope_counts_on_init = ( - variable_scope._get_default_variable_store().variable_scopes_count) + variable_scope.get_variable_scope_store().variable_scopes_count) def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" @@ -176,7 +177,7 @@ class Network(base.Layer): avoid_names = parent_network._owned_layers name_uid_map = parent_network._sub_layer_name_uids else: - name_uid_map = base._get_default_graph_uid_map() + name_uid_map = keras_base_layer.get_default_graph_uid_map() # Figure out which names we have to avoid based on which variable scope # we're nested in. strip_name = self._default_parent_variable_scope.name @@ -326,6 +327,8 @@ class Network(base.Layer): raise TypeError( "Network.track_layer() passed type %s, not a tf.layers.Layer" % (type(layer),)) + # Always use `ResourceVariable` with legacy layers. + layer._use_resource_variables = True if isinstance(layer, Network): layer._finalize_name(parent_network=self) else: @@ -639,7 +642,7 @@ def _make_custom_getter_for_deferred_restorations(): # Mark as already restored from this checkpoint. delayed_restoration.checkpointed_variables_to_restore[ checkpoint_name] = None - if context.in_graph_mode(): + if not context.executing_eagerly(): delayed_restoration.session.run(variable.initializer) if found_value: # Error checking should run even if we've already restored a value. @@ -772,7 +775,7 @@ def save_network_checkpoint( variable_map[mapped_name]._shared_name, variable._shared_name, network.scope_name)) - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -853,7 +856,7 @@ def _restore_existing_variables(network, save_path, map_func, user_map_func): network_name=network.name, network_scope_name=network.scope_name)) if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() @@ -880,7 +883,7 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, # _DeferredRestoration objects once a Network has been built (so that # restoring in a loop does not take increasing amounts of memory). if checkpointed_variables_to_restore: - if context.in_eager_mode(): + if context.executing_eagerly(): sess = None else: sess = ops.get_default_session() diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 3329fc6c513265deff41a368f5688dd605209c14..f43376d5d777a7f17d975e07b746f7b1c731e8ea 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -20,12 +20,10 @@ import gc from tensorflow.contrib.eager.python import network from tensorflow.contrib.layers.python.layers import regularizers -from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl -from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import math_ops @@ -469,36 +467,6 @@ class NetworkTest(test.TestCase): self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) - def testGraphOpNames(self): - """Network operation names should match variable naming.""" - - def _check_op_prefixes(expected_prefix, checked_ops): - for operation in ops.get_default_graph().get_operations(): - if operation.name == "ignore": - continue - if operation.name in checked_ops: - continue - checked_ops.add(operation.name) - self.assertStartsWith(expected_start=expected_prefix, - actual=operation.name) - self.assertNotIn("my_network", operation.name[len(expected_prefix):]) - self.assertNotIn("dense", operation.name[len(expected_prefix):]) - - with context.graph_mode(): - net = MyNetwork() - zero = constant_op.constant([[0.]], name="ignore") - net(zero) - checked_ops = set() - _check_op_prefixes(expected_prefix="my_network/dense/", - checked_ops=checked_ops) - net.net2 = net.track_layer(MyNetwork()) - net.net2(zero) - _check_op_prefixes(expected_prefix="my_network/my_network/dense/", - checked_ops=checked_ops) - MyNetwork()(zero) - _check_op_prefixes(expected_prefix="my_network_1/dense/", - checked_ops=checked_ops) - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testVariableRegularizers(self): net = RegularizedNetwork() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 62421849c766a1124c726812428985c913c653a3..fdaca90fd13576e6ca8a3408aaf528dbc2384b0c 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -73,7 +73,7 @@ def restore_variables_on_create(save_path, map_func=None): NotFoundError: If the variable is not found in checkpoint. ValueError: If not used in eager mode or map_func is not callable. """ - if context.in_graph_mode(): + if not context.executing_eagerly(): raise ValueError( "Currently, restore_variables_on_create can only be used with " "eager execution enabled.") @@ -131,7 +131,7 @@ class Saver(object): Raises: RuntimeError: if invoked when eager execution has not been enabled. """ - if context.in_graph_mode(): + if not context.executing_eagerly(): raise RuntimeError("tfe.Saver can only be used when eager " "execution is enabled. Use tf.train.Saver when " "building graphs.") diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 1a7f7b85e688e80e3cf482f2754462888187d311..4032e755f6e7dea9dcb42587f14e8386e5db2338 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -102,7 +102,6 @@ class SaverTest(test.TestCase): # Can still restore it. saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) - self.assertEqual(v1.read_value().numpy(), 1.0) # However, cannot restore it with default name. with self.assertRaisesOpError('not found in checkpoint'): saver = _saver.Saver([v1, v2]).restore(ckpt_prefix) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index d32bebf90c1e768d1efec26b3b78bf1a522a8f00..79dd117854e5fe9f066f671d8ce62e08579e0ed9 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -56,14 +56,24 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@save_network_checkpoint @@restore_network_checkpoint +@@Checkpoint +@@Checkpointable +@@CheckpointableSaver + +@@executing_eagerly @@in_eager_mode -@@in_graph_mode +@@set_execution_mode +@@execution_mode +@@async_wait +@@async_clear_error @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT +@@SYNC +@@ASYNC """ from __future__ import absolute_import @@ -87,11 +97,15 @@ from tensorflow.python.eager import function from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT -from tensorflow.python.eager.context import in_eager_mode -from tensorflow.python.eager.context import in_graph_mode +from tensorflow.python.eager.context import executing_eagerly from tensorflow.python.eager.context import list_devices +from tensorflow.python.eager.context import set_execution_mode +from tensorflow.python.eager.context import execution_mode +from tensorflow.python.eager.context import async_wait +from tensorflow.python.eager.context import async_clear_error +from tensorflow.python.eager.context import SYNC +from tensorflow.python.eager.context import ASYNC from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback @@ -101,10 +115,14 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template +from tensorflow.python.training.checkpointable import Checkpointable +from tensorflow.python.training.checkpointable_utils import CheckpointableSaver +from tensorflow.python.training.checkpointable_utils import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func @@ -115,5 +133,6 @@ implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function value_and_gradients_function = backprop.val_and_grad_function GradientTape = backprop.GradientTape # pylint: disable=invalid-name +in_eager_mode = executing_eagerly remove_undocumented(__name__) diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index b6659c2a1797feab261d756e78b45231dbea5a02..e80ccbb74d8623e977a98cb7fa5eb41f3c9bf250 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -47,7 +47,8 @@ class TFETest(test_util.TensorFlowTestCase): def testVariableError(self): with self.assertRaisesRegexp( - RuntimeError, r'Variable not supported in Eager mode'): + RuntimeError, + r'Variable not supported when eager execution is enabled'): variables.Variable(initial_value=1.0) def testGradients(self): diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 6cdbed5b896577f5622b1bd0123c289c798bc0a5..9f4cd44afbede286966ba0e7357c5dac92a2b729 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -9,23 +9,12 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "estimator_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":boosted_trees", ":dnn", ":dnn_linear_combined", ":extenders", @@ -34,10 +23,41 @@ py_library( ":logit_fns", ":multi_head", ":replicate_model_fn", + ":rnn", "//tensorflow/python:util", ], ) +py_library( + name = "boosted_trees", + srcs = ["python/estimator/boosted_trees.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:boosted_trees", + ], +) + +py_test( + name = "boosted_trees_test", + size = "medium", + srcs = ["python/estimator/boosted_trees_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":boosted_trees", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], +) + py_library( name = "dnn", srcs = ["python/estimator/dnn.py"], @@ -70,6 +90,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -110,6 +131,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -138,9 +160,11 @@ py_test( size = "medium", srcs = ["python/estimator/extenders_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 deps = [ ":extenders", "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/predictor", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", @@ -169,9 +193,11 @@ py_library( "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", + "//tensorflow/python:nn", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", @@ -191,6 +217,7 @@ py_test( ":head", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -242,6 +269,7 @@ py_test( "//tensorflow/python/estimator:numpy_io", "//tensorflow/python/estimator:prediction_keys", "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -288,6 +316,8 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:metrics", "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", @@ -351,6 +381,7 @@ cuda_py_test( size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ + "@absl_py//absl/testing:parameterized", "//tensorflow/python/estimator", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:export_export", @@ -382,3 +413,57 @@ cuda_py_test( "notap", ], ) + +py_library( + name = "rnn", + srcs = ["python/estimator/rnn.py"], + srcs_version = "PY2AND3", + deps = [ + ":extenders", + "//tensorflow/contrib/feature_column:feature_column_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + +py_test( + name = "rnn_test", + size = "medium", + srcs = ["python/estimator/rnn_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":rnn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 0f75b77050b0ba4c752a6a74fdc7024170b6f318..be20d1b7770d3f3df21ac9c0f811d924bf4152ee 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.extenders import * @@ -27,6 +28,7 @@ from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * +from tensorflow.contrib.estimator.python.estimator.rnn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -39,15 +41,19 @@ _allowed_symbols = [ 'multi_class_head', 'multi_head', 'multi_label_head', + 'poisson_regression_head', 'regression_head', 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', + 'boosted_trees_classifier_train_in_memory', + 'boosted_trees_regressor_train_in_memory', 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', 'replicate_model_fn', 'TowerOptimizer', + 'RNNClassifier', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..bd641014e9eec6623d66574bccd08ff03ebc28ac --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -0,0 +1,357 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Boosted Trees estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees + + +def _validate_input_fn_and_repeat_dataset(train_input_fn): + """Validates whether the input_fn is valid, and repeat() if tf.Dataset.""" + def _input_fn(): + result_input_fn = train_input_fn() + if isinstance(result_input_fn, dataset_ops.Dataset): + return result_input_fn.repeat() + return result_input_fn + + return _input_fn + + +class _BoostedTreesEstimator(estimator.Estimator): + """An Estimator for Tensorflow Boosted Trees models.""" + + def __init__(self, + feature_columns, + n_batches_per_layer, + head, + model_dir=None, + weight_column=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + min_node_weight=0., + config=None): + """Initializes a `BoostedTreesEstimator` instance. + + Args: + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + n_batches_per_layer: the number of batches to collect statistics per + layer. + head: the `Head` instance defined for Estimator. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). + config: `RunConfig` object to configure the runtime settings. + """ + # pylint:disable=protected-access + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, labels, mode, head, feature_columns, tree_hparams, + n_batches_per_layer, config) + + super(_BoostedTreesEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + # pylint:enable=protected-access + + +def boosted_trees_classifier_train_in_memory( + train_input_fn, + feature_columns, + model_dir=None, + n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT, + weight_column=None, + label_vocabulary=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + min_node_weight=0., + config=None, + train_hooks=None): + """Trains a boosted tree classifier with in memory dataset. + + Example: + + ```python + bucketized_feature_1 = bucketized_column( + numeric_column('feature_1'), BUCKET_BOUNDARIES_1) + bucketized_feature_2 = bucketized_column( + numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + + def train_input_fn(): + dataset = create-dataset-from-training-data + # This is tf.data.Dataset of a tuple of feature dict and label. + # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), + # Dataset.from_tensors(label_array))) + # The returned Dataset shouldn't be batched. + # If Dataset repeats, only the first repetition would be used for training. + return dataset + + classifier = boosted_trees_classifier_train_in_memory( + train_input_fn, + feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_trees=100, + ... + ) + + def input_fn_eval(): + ... + return dataset + + metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10) + ``` + + Args: + train_input_fn: the input function returns a dataset containing a single + epoch of *unbatched* features and labels. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + n_classes: number of label classes. Default is binary classification. + Multiclass support is not yet implemented. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). + config: `RunConfig` object to configure the runtime settings. + train_hooks: a list of Hook instances to be passed to estimator.train(). + + Returns: + a `BoostedTreesClassifier` instance created with the given arguments and + trained with the data loaded up on memory from the input_fn. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + # pylint: disable=protected-access + # TODO(nponomareva): Support multi-class cases. + if n_classes == canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT: + n_classes = 2 + head, closed_form = ( + canned_boosted_trees._create_classification_head_and_closed_form( + n_classes, weight_column, label_vocabulary=label_vocabulary)) + + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer=1, + config=config, + closed_form_grad_and_hess_fn=closed_form, + train_in_memory=True) + + in_memory_classifier = estimator.Estimator( + model_fn=_model_fn, model_dir=model_dir, config=config) + + in_memory_classifier.train( + input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), + hooks=train_hooks) + + return in_memory_classifier + # pylint: enable=protected-access + + +def boosted_trees_regressor_train_in_memory( + train_input_fn, + feature_columns, + model_dir=None, + label_dimension=canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT, + weight_column=None, + n_trees=100, + max_depth=6, + learning_rate=0.1, + l1_regularization=0., + l2_regularization=0., + tree_complexity=0., + min_node_weight=0., + config=None, + train_hooks=None): + """Trains a boosted tree regressor with in memory dataset. + + Example: + + ```python + bucketized_feature_1 = bucketized_column( + numeric_column('feature_1'), BUCKET_BOUNDARIES_1) + bucketized_feature_2 = bucketized_column( + numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + + def train_input_fn(): + dataset = create-dataset-from-training-data + # This is tf.data.Dataset of a tuple of feature dict and label. + # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), + # Dataset.from_tensors(label_array))) + # The returned Dataset shouldn't be batched. + # If Dataset repeats, only the first repetition would be used for training. + return dataset + + regressor = boosted_trees_regressor_train_in_memory( + train_input_fn, + feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_trees=100, + ... + ) + + def input_fn_eval(): + ... + return dataset + + metrics = regressor.evaluate(input_fn=input_fn_eval, steps=10) + ``` + + Args: + train_input_fn: the input function returns a dataset containing a single + epoch of *unbatched* features and labels. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + label_dimension: Number of regression targets per example. + Multi-dimensional support is not yet implemented. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to downweight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + n_trees: number trees to be created. + max_depth: maximum depth of the tree to grow. + learning_rate: shrinkage parameter to be used when a tree added to the + model. + l1_regularization: regularization multiplier applied to the absolute + weights of the tree leafs. + l2_regularization: regularization multiplier applied to the square weights + of the tree leafs. + tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). + config: `RunConfig` object to configure the runtime settings. + train_hooks: a list of Hook instances to be passed to estimator.train(). + + Returns: + a `BoostedTreesClassifier` instance created with the given arguments and + trained with the data loaded up on memory from the input_fn. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + # pylint: disable=protected-access + # TODO(nponomareva): Extend it to multi-dimension cases. + if label_dimension == canned_boosted_trees._HOLD_FOR_MULTI_DIM_SUPPORT: + label_dimension = 1 + head = canned_boosted_trees._create_regression_head(label_dimension, + weight_column) + + # HParams for the model. + tree_hparams = canned_boosted_trees._TreeHParams( + n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, + tree_complexity, min_node_weight) + + def _model_fn(features, labels, mode, config): + return canned_boosted_trees._bt_model_fn( + features, + labels, + mode, + head, + feature_columns, + tree_hparams, + n_batches_per_layer=1, + config=config, + train_in_memory=True) + + in_memory_regressor = estimator.Estimator( + model_fn=_model_fn, model_dir=model_dir, config=config) + + in_memory_regressor.train( + input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), + hooks=train_hooks) + + return in_memory_regressor + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py new file mode 100644 index 0000000000000000000000000000000000000000..76cbefe5e94502188388df6fc2816d130ac896d5 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -0,0 +1,222 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests boosted_trees estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.estimator.python.estimator import boosted_trees +from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest +from tensorflow.python.training import checkpoint_utils + +NUM_FEATURES = 3 + +BUCKET_BOUNDARIES = [-2., .5, 12.] # Boundaries for all the features. +INPUT_FEATURES = np.array( + [ + [12.5, 1.0, -2.001, -2.0001, -1.999], # feature_0 quantized:[3,2,0,0,1] + [2.0, -3.0, 0.5, 0.0, 0.4995], # feature_1 quantized:[2,0,2,1,1] + [3.0, 20.0, 50.0, -100.0, 102.75], # feature_2 quantized:[2,3,3,0,3] + ], + dtype=np.float32) +CLASSIFICATION_LABELS = [[0.], [1.], [1.], [0.], [0.]] +REGRESSION_LABELS = [[1.5], [0.3], [0.2], [2.], [5.]] +FEATURES_DICT = {'f_%d' % i: INPUT_FEATURES[i] for i in range(NUM_FEATURES)} + + +def _make_train_input_fn(is_classification): + """Makes train input_fn for classification/regression.""" + + def _input_fn(): + features_dict = dict(FEATURES_DICT) + labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS + return features_dict, labels + + return _input_fn + + +def _make_train_input_fn_dataset(is_classification): + """Makes input_fn using Dataset.""" + + def _input_fn(): + features_dict = dict(FEATURES_DICT) + labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS + ds = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(features_dict), + dataset_ops.Dataset.from_tensors(labels) + )) + return ds + + return _input_fn + + +class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._head = canned_boosted_trees._create_regression_head(label_dimension=1) + self._feature_columns = { + feature_column.bucketized_column( + feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32), + BUCKET_BOUNDARIES) + for i in range(NUM_FEATURES) + } + + def _assert_checkpoint(self, model_dir, global_step, finalized_trees, + attempted_layers): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + serialized = reader.get_tensor('boosted_trees:0_serialized') + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + self.assertEqual( + finalized_trees, + sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized])) + self.assertEqual(attempted_layers, + ensemble_proto.growing_metadata.num_layers_attempted) + + def testTrainAndEvaluateEstimator(self): + input_fn = _make_train_input_fn(is_classification=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + head=self._head, + max_depth=5) + + # It will stop after 10 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10) + eval_res = est.evaluate(input_fn=input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 1.008551) + + def testInferEstimator(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees._BoostedTreesEstimator( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + head=self._head) + + # It will stop after 5 steps because of the max depth and num trees. + num_steps = 100 + # Train for a few steps, and validate final checkpoint. + est.train(train_input_fn, steps=num_steps) + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self): + train_input_fn = _make_train_input_fn(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + # Validate predictions. + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + + def testBinaryClassifierTrainInMemoryWithDataset(self): + train_input_fn = _make_train_input_fn_dataset(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + + def testRegressorTrainInMemoryAndEvalAndInfer(self): + train_input_fn = _make_train_input_fn(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_regressor_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.478283) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + def testRegressorTrainInMemoryWithDataset(self): + train_input_fn = _make_train_input_fn_dataset(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_regressor_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.478283) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py index b5e4d34dc70ccaa4806ae8b8ed5001bd971ee7b4..dd009a6753f3231638f93e50fc8f19eae8820139 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -34,6 +34,7 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -52,7 +53,9 @@ def _dnn_only_estimator_fn( config=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, dnn_feature_columns=feature_columns, dnn_optimizer=optimizer, @@ -100,7 +103,9 @@ def _linear_only_estimator_fn( partitioner=None): return dnn_linear_combined.DNNLinearCombinedEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), model_dir=model_dir, linear_feature_columns=feature_columns, linear_optimizer=optimizer, diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 71f810acec856d42d389260e7b9fea32123348b4..75e3107670d658e55ce23d983e47311f1c180104 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -41,7 +42,9 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index c99bf8badb35e6fffb7cae8761db9d402b8b3a8f..201699ed775f701bc9f215fff11a688175d51645 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops @@ -33,7 +34,7 @@ _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) def add_metrics(estimator, metric_fn): - """Creates a new ${tf.estimator.Estimator} which has given metrics. + """Creates a new @{tf.estimator.Estimator} which has given metrics. Example: @@ -60,7 +61,7 @@ def add_metrics(estimator, metric_fn): ``` Args: - estimator: A ${tf.estimator.Estimator} object. + estimator: A @{tf.estimator.Estimator} object. metric_fn: A function which should obey the following signature: - Args: can only have following four arguments in any order: * predictions: Predictions `Tensor` or dict of `Tensor` created by given @@ -78,7 +79,7 @@ def add_metrics(estimator, metric_fn): function, namely a `(metric_tensor, update_op)` tuple. Returns: - A new ${tf.estimator.Estimator} which has a union of original metrics with + A new @{tf.estimator.Estimator} which has a union of original metrics with given ones. """ _verify_metric_fn_args(metric_fn) @@ -96,7 +97,10 @@ def add_metrics(estimator, metric_fn): return estimator_lib.Estimator( model_fn=new_model_fn, model_dir=estimator.model_dir, - config=estimator.config) + config=estimator.config, + # pylint: disable=protected-access + warm_start_from=estimator._warm_start_settings) + # pylint: enable=protected-access def clip_gradients_by_norm(optimizer, clip_norm): @@ -161,14 +165,14 @@ def forward_features(estimator, keys=None): ``` Args: - estimator: A ${tf.estimator.Estimator} object. + estimator: A @{tf.estimator.Estimator} object. keys: a `string` or a `list` of `string`. If it is `None`, all of the `features` in `dict` is forwarded to the `predictions`. If it is a `string`, only given key is forwarded. If it is a `list` of strings, all the given `keys` are forwarded. Returns: - A new ${tf.estimator.Estimator} which forwards features to predictions. + A new @{tf.estimator.Estimator} which forwards features to predictions. Raises: ValueError: @@ -233,7 +237,17 @@ def forward_features(estimator, keys=None): 'argument of forward_features to filter unwanted features. Type of ' 'features[{}] is {}.'.format(key, key, type(feature))) predictions[key] = feature - return spec._replace(predictions=predictions) + spec = spec._replace(predictions=predictions) + if spec.export_outputs: + for ekey in ['predict', 'serving_default']: + if (ekey in spec.export_outputs and + isinstance(spec.export_outputs[ekey], + PredictOutput)): + export_outputs = spec.export_outputs[ekey].outputs + for key in get_keys(features): + export_outputs[key] = predictions[key] + + return spec return estimator_lib.Estimator( model_fn=new_model_fn, diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index ad1a8ef152b07ecbab33d9eb3184a2ae89def27d..407af2deaf0928361a4f0b0e44e842b7750118cb 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -18,20 +18,27 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import tempfile import numpy as np from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.predictor import from_saved_model from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.training import training +from tensorflow.python.util import compat def get_input_fn(x, y): @@ -177,6 +184,44 @@ class ForwardFeaturesTest(test.TestCase): self.assertIn('id', predictions) self.assertEqual(101, predictions['id']) + def test_forward_in_exported(self): + + def serving_input_fn(): + features_ph = { + 'x': array_ops.placeholder(dtypes.float32, [None]), + 'id': array_ops.placeholder(dtypes.int32, [None]) + } + features = { + key: array_ops.expand_dims(tensor, -1) + for key, tensor in features_ph.items() + } + return estimator_lib.export.ServingInputReceiver(features, features_ph) + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + # create estimator + feature_columns = [fc.numeric_column('x')] + estimator = linear.LinearRegressor(feature_columns) + estimator.train(input_fn=input_fn, steps=1) + estimator = extenders.forward_features(estimator, 'id') + + # export saved model + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn) + self.assertTrue(gfile.Exists(export_dir)) + + # restore model + predict_fn = from_saved_model(export_dir, signature_def_key='predict') + predictions = predict_fn({'x': [3], 'id': [101]}) + + # verify that 'id' exists in predictions + self.assertIn('id', predictions) + self.assertEqual(101, predictions['id']) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + def test_forward_list(self): def input_fn(): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 238cf287b768eee28b20202084eb244c085c8b75..3dcf0374c8a12b5907fbaf20d1ad72211a45ab5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -31,10 +31,12 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import nn from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -42,7 +44,7 @@ _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for multi class classification. @@ -83,7 +85,8 @@ def multi_class_head(n_classes, have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -108,7 +111,7 @@ def binary_classification_head( weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for single label binary classification. @@ -152,7 +155,8 @@ def binary_classification_head( `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -175,8 +179,9 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + inverse_link_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -195,10 +200,16 @@ def regression_head(weight_column=None, `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. - Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with shape `[D0, D1, ... DN, label_dimension]`. + Also supports custom `inverse_link_fn`, also known as 'mean function'. + `inverse_link_fn` takes `logits` as argument and returns predicted values. + This function is the inverse of the link function defined in + https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function + Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -208,8 +219,12 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. - loss_fn: Optional loss function. + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + loss_fn: Optional loss function. Defaults to `mean_squared_error`. + inverse_link_fn: Optional inverse link function, also known as 'mean + function'. Defaults to identity. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -224,6 +239,69 @@ def regression_head(weight_column=None, label_dimension=label_dimension, loss_reduction=loss_reduction, loss_fn=loss_fn, + inverse_link_fn=inverse_link_fn, + name=name) + + +def poisson_regression_head( + weight_column=None, + label_dimension=1, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, + compute_full_loss=True, + name=None): + """Creates a `_Head` for poisson regression using `tf.nn.log_poisson_loss`. + + The loss is the weighted sum over all input dimensions. Namely, if the input + labels have shape `[batch_size, label_dimension]`, the loss is the weighted + sum over both `batch_size` and `label_dimension`. + + The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. + In many applications, the shape is `[batch_size, label_dimension]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape + `[D0, D1, ... DN]` is also supported. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or + `[D0, D1, ... DN, label_dimension]`. + + This is implemented as a generalized linear model, see + https://en.wikipedia.org/wiki/Generalized_linear_model. + + Args: + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + label_dimension: Number of regression labels per example. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + compute_full_loss: Whether to include the constant `log(z!)` term in + computing the poisson loss. See `tf.nn.log_poisson_loss` for the full + documentation. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. + + Returns: + An instance of `_Head` for poisson regression. + + Raises: + ValueError: If `label_dimension` or `loss_reduction` is invalid. + """ + def _poisson_loss(labels, logits): + return nn.log_poisson_loss( + targets=labels, log_input=logits, compute_full_loss=compute_full_loss) + return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=label_dimension, + loss_reduction=loss_reduction, + loss_fn=_poisson_loss, + inverse_link_fn=math_ops.exp, name=name) @@ -231,7 +309,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -282,7 +360,8 @@ def multi_label_head(n_classes, string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to - reduce training loss over batch. Defaults to `SUM`. + reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely + weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -331,7 +410,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, - loss_reduction=losses.Reduction.SUM, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, name=None): self._n_classes = n_classes @@ -406,7 +485,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access reduction=losses.Reduction.NONE) # Averages loss over classes. unweighted_loss = math_ops.reduce_mean( - unweighted_loss, axis=-1, keep_dims=True) + unweighted_loss, axis=-1, keepdims=True) weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, features=features, weight_column=self._weight_column, logits=logits) training_loss = losses.compute_weighted_loss( @@ -418,8 +497,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access processed_labels=processed_labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None, - regularization_losses=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None, regularization_losses=None): """Returns an `EstimatorSpec`. Args: @@ -431,8 +510,11 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when `mode` equals `TRAIN` or `EVAL`. + optimizer: `Optimizer` instance to optimize the loss in TRAIN mode. + Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which + updates variables and increments `global_step`. train_op_fn: Function that takes a scalar loss `Tensor` and returns - `train_op`. Required in TRAIN mode. + `train_op`. Used if `optimizer` is `None`. regularization_losses: A list of additional scalar losses to be added to the training loss, such as regularization losses. These losses are usually expressed as a batch average, so for best results users need to @@ -442,7 +524,8 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access Returns: `EstimatorSpec`. Raises: - ValueError: If `train_op_fn` is `None` in TRAIN mode. + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode, or if both are set. """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -494,8 +577,16 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access regularization_loss=regularization_loss)) # Train. - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + regularized_training_loss, + global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(regularized_training_loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: @@ -521,7 +612,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, - train_op=train_op_fn(regularized_training_loss)) + train_op=train_op) def _eval_metric_ops( self, labels, probabilities, weights, unreduced_loss, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 1411635228457218578c0297d4d901e9c86ca91a..98962ca4277a3e8fbbdb3fb2d26df9acc45168b5 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops @@ -271,9 +272,9 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - expected_training_loss = np.sum( + # loss = (labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / 2 + expected_training_loss = 0.5 * np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -297,7 +298,7 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_training_loss = np.sum( + expected_training_loss = 0.5 * np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -360,7 +361,7 @@ class MultiLabelHead(test.TestCase): labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_training_loss.eval()) + self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -437,16 +438,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -467,18 +469,17 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -509,7 +510,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -532,18 +533,17 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, } self._test_eval( head=head, @@ -561,19 +561,18 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) - ) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, + keys.LOSS_MEAN: expected_loss, # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.5972, + keys.AUC_PR: 0.7639, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3., keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3., @@ -602,8 +601,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2 + expected_loss = 12.5 spec = head.create_estimator_spec( features={ @@ -616,12 +616,12 @@ class MultiLabelHead(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { - # Average loss over weighted examples. - keys.LOSS_MEAN: expected_loss / 3, + # Average loss over weighted examples (denominator is sum(weights)). + keys.LOSS_MEAN: expected_loss * (2. / 3.), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, - keys.AUC_PR: 0.5833, + keys.AUC_PR: 0.7833, } # Assert spec contains expected tensors. @@ -662,7 +662,7 @@ class MultiLabelHead(test.TestCase): # (1 - labels) * (logits > 0) * logits expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] expected_weights = [[1.], [2.]] - expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), @@ -808,11 +808,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol) def test_train(self): head = head_lib.multi_label_head(n_classes=2) @@ -822,8 +819,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -839,8 +837,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) @@ -857,11 +856,49 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_optimizer(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over examples, divide by batch_size. + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) @@ -915,8 +952,9 @@ class MultiLabelHead(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, weighted sum over examples. - expected_loss = 25. + # Average over classes, weighted sum over examples, divide by batch_size. + # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2 + expected_loss = 12.5 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -950,11 +988,8 @@ class MultiLabelHead(test.TestCase): self.assertEqual( six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) - _assert_simple_summaries(self, { - metric_keys.MetricKeys.LOSS: expected_loss, - # Average loss over weighted examples. - metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, - }, summary_str, tol) + _assert_simple_summaries( + self, {metric_keys.MetricKeys.LOSS: expected_loss,}, summary_str, tol) def test_multi_dim_weighted_train_create_loss(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" @@ -971,8 +1006,8 @@ class MultiLabelHead(test.TestCase): expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] # weights are reshaped to [2, 2, 1] to match logits. expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_training_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_training_loss = 9.9167 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, @@ -998,8 +1033,8 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -1087,15 +1122,15 @@ class MultiLabelHead(test.TestCase): weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 # = [[20/3, 10/3], [4, 8]] - # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_loss = 39.6667 + # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167 + expected_loss = 9.9167 keys = metric_keys.MetricKeys expected_metrics = { - keys.LOSS_MEAN: expected_loss / np.sum(weights), + keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)), # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, - keys.AUC_PR: 0.4037, + keys.AUC_PR: 0.6645, } self._test_eval( head=head, @@ -1106,5 +1141,75 @@ class MultiLabelHead(test.TestCase): expected_metrics=expected_metrics) +class PoissonRegressionHead(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def test_train(self): + head = head_lib.poisson_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[1], [2], [3]], dtype=np.int32) + # With x = exp(logits), z = labels. + # loss = -ln(exp(-x) * (x^z) / z!) + # = x - z * ln(x) + ln(z!) + # = exp(logits) - labels * logits - ln(labels!) + # But for ln(z!) and z > 1, the Stirling approximation is used + # ln(z!) = z*ln(z) - z + 0.5*ln(2*pi*z) + # loss = [exp(0) - 1 * 0 + ln(1!), + # exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2), + # exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)] + # = [1.0, 3.020, 1.482] + # training_loss = (1.0 + 3.020 + 1.482) / 3 + expected_loss = 1.834 + atol = 0.001 + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_near( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + atol=atol, name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run([spec.loss, spec.train_op]) + self.assertAlmostEqual(expected_loss, loss, delta=atol) + self.assertEqual(expected_train_result, train_result) + + def test_predict(self): + head = head_lib.poisson_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + expected_predictions = np.exp(logits) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py index c63514eb688af48577f0a3b7ce9e7478309f2c30..c41996b9c6871d294f157411662f2eb9d4c09e5c 100644 --- a/tensorflow/contrib/estimator/python/estimator/linear_test.py +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -32,6 +32,7 @@ from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -42,7 +43,9 @@ def _linear_estimator_fn( """Returns a LinearEstimator that uses regression_head.""" return linear.LinearEstimator( head=head_lib.regression_head( - weight_column=weight_column, label_dimension=label_dimension), + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), *args, **kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 0346ddc24bffd61068177f4622bd03be4acd53d9..ce758992140d43529037b14cbbf958d5aa763fb4 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -23,6 +23,7 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -30,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +from tensorflow.python.training import training_util _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -226,8 +228,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access weights=example_weights_by_head, processed_labels=labels_by_head) + # TODO(b/65403806): Support regularization_losses arg. def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): + self, features, mode, logits, labels=None, optimizer=None, + train_op_fn=None): """See `_Head`.""" if isinstance(logits, dict): logits_dict = logits @@ -248,9 +252,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access train_op_fn=_no_op_train_fn)) if mode == model_fn.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None in TRAIN mode.') - spec = self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train( + all_estimator_spec=all_estimator_spec, + optimizer=optimizer, + train_op_fn=train_op_fn) with ops.name_scope(''): summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) return spec @@ -279,16 +284,21 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access begin_idx += head.logits_dimension return logits_dict - def _merge_train(self, all_estimator_spec, train_op_fn): + def _merge_train(self, all_estimator_spec, optimizer, train_op_fn): """Merges list of `EstimatorSpec` for training. Args: all_estimator_spec: list of `EstimatorSpec` for the individual heads. - train_op_fn: Function to create train op. See `create_estimator_spec` - documentation for more details. + optimizer: `Optimizer` instance to create train op. See + `create_estimator_spec` documentation for more details. + train_op_fn: Function to create train op. Used if `optimizer` is `None`. Returns: `EstimatorSpec` that merges all heads for TRAIN. + + Raises: + ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN + mode. """ losses = [] metrics = {} @@ -297,11 +307,20 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access # Metric keys already contain head.name. metrics.update(spec.eval_metric_ops or {}) loss = _merge_losses(losses, self._head_weights) + if optimizer is not None: + if train_op_fn is not None: + raise ValueError('train_op_fn and optimizer cannot both be set.') + train_op = optimizer.minimize( + loss, global_step=training_util.get_global_step()) + elif train_op_fn is not None: + train_op = train_op_fn(loss) + else: + raise ValueError('train_op_fn and optimizer cannot both be None.') return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, loss=loss, - train_op=train_op_fn(loss), + train_op=train_op, eval_metric_ops=metrics) def _merge_predict(self, all_estimator_spec): @@ -319,16 +338,24 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access all_estimator_spec[0].export_outputs, self._heads[0].name), } + merged_predict_outputs = {} for head, spec in zip(self._heads, all_estimator_spec): head_name = head.name for k, v in six.iteritems(spec.export_outputs): if k == _DEFAULT_SERVING_KEY: key = head_name else: - key = '%s/%s' % (k, head_name) + key = '%s/%s' % (head_name, k) export_outputs[key] = v + if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access + isinstance(v, export_output_lib.PredictOutput)): + for kp, vp in six.iteritems(v.outputs): + key = '%s/%s' % (head_name, kp) + merged_predict_outputs[key] = vp for k, v in six.iteritems(spec.predictions): predictions[(head_name, k)] = v + export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access + export_output_lib.PredictOutput(merged_predict_outputs)) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index e47a6788f3b5440c4906b9f0430c802cf73237e3..3d6fccb1180c435f64552667306be004437f62ba 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', + 'head1/predict', 'head2', 'head2/classification', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -158,6 +158,22 @@ class MultiHeadTest(test.TestCase): self.assertAllClose( expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['predict'].outputs['head1/probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['predict'].outputs['head2/probabilities'])) + self.assertAllClose( + expected_probabilities['head1'], + sess.run( + spec.export_outputs['head1/predict'].outputs['probabilities'])) + self.assertAllClose( + expected_probabilities['head2'], + sess.run( + spec.export_outputs['head2/predict'].outputs['probabilities'])) def test_predict_two_heads_logits_tensor(self): """Tests predict with logits as Tensor.""" @@ -181,8 +197,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', - 'head2', 'classification/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification', + 'head1/predict', 'head2', 'head2/classification', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -238,8 +254,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', - 'head2', 'regression/head2', 'predict/head2'), + (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression', + 'head1/predict', 'head2', 'head2/regression', 'head2/predict'), spec.export_outputs.keys()) # Assert predictions and export_outputs. @@ -283,10 +299,11 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15. expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 spec = multi_head.create_estimator_spec( @@ -300,14 +317,14 @@ class MultiHeadTest(test.TestCase): keys.LOSS + '/head1': expected_loss_head1, keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. - keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, + keys.LOSS_MEAN + '/head1': expected_loss_head1, + keys.LOSS_MEAN + '/head2': expected_loss_head2, # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but # this assert tests that the algorithm remains consistent. keys.AUC + '/head1': 0.1667, keys.AUC + '/head2': 0.3333, - keys.AUC_PR + '/head1': 0.49999964, - keys.AUC_PR + '/head2': 0.33333313, + keys.AUC_PR + '/head1': 0.6667, + keys.AUC_PR + '/head2': 0.5000, } # Assert spec contains expected tensors. @@ -347,8 +364,8 @@ class MultiHeadTest(test.TestCase): tol = 1e-3 with self.test_session(): # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] - # (averaged over classes, sum-reduced over examples). - self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) + # (averaged over classes, averaged over examples). + self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): # Use different example weighting for each head weighting. @@ -383,18 +400,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -431,18 +448,18 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # training_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5 # head-weighted unreduced_loss = 1 * [10, 7.5] self.assertAllClose( [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # training_loss = 2 * 20 + 3 * 10 = 70 + # training_loss = (2 * 20 + 3 * 10) / 2 = 35 # head-weighted unreduced_loss = 2 * [20, 10] self.assertAllClose( [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) - # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5 + self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol) # head-weighted example weights self.assertAllClose( [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) @@ -466,14 +483,14 @@ class MultiHeadTest(test.TestCase): [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), } # Loss for the first head: - # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + - # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 - # = 28 + # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8 + # = 3.5 # Loss for the second head: - # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + - # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 - # = 74 - expected_training_loss = 28. + 74. + # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12 + # = 6.167 + expected_training_loss = 3.5 + 6.167 training_loss = multi_head.create_loss( features={}, @@ -495,8 +512,8 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -530,10 +547,46 @@ class MultiHeadTest(test.TestCase): _assert_simple_summaries(self, { metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, }, summary_str, tol) + def test_train_one_head_with_optimizer(self): + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + multi_head = multi_head_lib.multi_head([head1]) + + logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} + labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 + expected_loss = 8.75 + expected_train_result = 'my_train_op' + + class _Optimizer(object): + + def minimize(self, loss, global_step): + del global_step + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + optimizer=_Optimizer()) + + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + def test_train_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -553,10 +606,12 @@ class MultiHeadTest(test.TestCase): # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] + # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] + # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15 # Average over classes, weighted sum over batch and heads. - expected_loss_head1 = 17.5 - expected_loss_head2 = 30.0 + expected_loss_head1 = 8.75 + expected_loss_head2 = 15.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -592,9 +647,6 @@ class MultiHeadTest(test.TestCase): metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, - # Average loss over examples. - metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, - metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, }, summary_str, tol) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index 7134cd3f5a457a322f51066eb791133c3181d3fb..a8774d6dab9205439e6e312827f9cd1306e3f1ea 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -110,7 +110,8 @@ def replicate_model_fn(model_fn, Certain algorithms were chosen for aggregating results of computations on multiple towers: - Losses from all towers are reduced according to `loss_reduction`. - - Gradients are reduced using sum for each trainable variable. + - Gradients from all towers are reduced according to `loss_reduction` + for each trainable variable. - `eval_metrics_ops` are reduced per metric using `reduce_mean`. - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are reduced using concatenation. @@ -135,7 +136,7 @@ def replicate_model_fn(model_fn, the train_op argument of `EstimatorSpec`. loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This - argument can be used to replice only on the subset of available GPUs. + argument can be used to replicate only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. @@ -455,7 +456,7 @@ def _get_local_devices(device_type): def _split_batch(features, labels, number_of_shards, device): - """Split input features and labes into batches.""" + """Split input features and labels into batches.""" def ensure_divisible_by_shards(sequence): batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0] @@ -601,7 +602,7 @@ def _local_device_setter(worker_device, ps_devices, ps_strategy): def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers): - """Produce an EstimatorSpec with approproriately scaled loss.""" + """Produce an EstimatorSpec with appropriately scaled loss.""" if tower_spec.loss is None: return tower_spec diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index d46a18aacfcd911c56a9f22dc9581060c7b458a6..144b45982c8aec2e2b115c812b24e8843d60ce1e 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import re import shutil import tempfile +from absl.testing import parameterized import numpy as np import six @@ -57,26 +58,19 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import training -# TODO(isaprykin): Parametrize all the tests on -# replicate_model_fn._VariableDistributionMode when it's supported. -class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase, + parameterized.TestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow_with_public_version(self): - return self._complete_flow_with_mode(mode=None) - - def test_complete_flow_with_mode_local_ps_server(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode. - SHARED_LOCAL_PARAMETER_SERVER) - - def test_complete_flow_with_mode_round_robin(self): - return self._complete_flow_with_mode( - replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) - - def _complete_flow_with_mode(self, mode): + @parameterized.named_parameters( + ('PublicInterface', None), + ('ParameterServerMode', replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER), + ('RoundRobinMode', + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN)) + def test_complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..b475c12f5af3aedc766a0880a98c5c1e29bddbb7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -0,0 +1,481 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Recurrent Neural Network estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import training_util + + +# The defaults are historical artifacts of the initial implementation, but seem +# reasonable choices. +_DEFAULT_LEARNING_RATE = 0.05 +_DEFAULT_CLIP_NORM = 5.0 + +_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell, + 'lstm': rnn_cell.BasicLSTMCell, + 'gru': rnn_cell.GRUCell} + +# Indicates no value was provided by the user to a kwarg. +USE_DEFAULT = object() + + +def _single_rnn_cell(num_units, cell_type): + cell_type = _CELL_TYPES.get(cell_type, cell_type) + if not cell_type or not issubclass(cell_type, rnn_cell.RNNCell): + raise ValueError('Supported cell types are {}; got {}'.format( + list(_CELL_TYPES.keys()), cell_type)) + return cell_type(num_units=num_units) + + +def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): + """Convenience function to create `rnn_cell_fn` for canned RNN Estimators. + + Args: + num_units: Iterable of integer number of hidden units per RNN layer. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. + + Returns: + A function that takes a single argument, an instance of + `tf.estimator.ModeKeys`, and returns an instance derived from + `tf.nn.rnn_cell.RNNCell`. + + Raises: + ValueError: If cell_type is not supported. + """ + def rnn_cell_fn(mode): + # Unused. Part of the rnn_cell_fn interface since user specified functions + # may need different behavior across modes (e.g. dropout). + del mode + cells = [_single_rnn_cell(n, cell_type) for n in num_units] + if len(cells) == 1: + return cells[0] + return rnn_cell.MultiRNNCell(cells) + return rnn_cell_fn + + +def _concatenate_context_input(sequence_input, context_input): + """Replicates `context_input` across all timesteps of `sequence_input`. + + Expands dimension 1 of `context_input` then tiles it `sequence_length` times. + This value is appended to `sequence_input` on dimension 2 and the result is + returned. + + Args: + sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, + padded_length, d0]`. + context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. + + Returns: + A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, + d0 + d1]`. + + Raises: + ValueError: If `sequence_input` does not have rank 3 or `context_input` does + not have rank 2. + """ + seq_rank_check = check_ops.assert_rank( + sequence_input, + 3, + message='sequence_input must have rank 3', + data=[array_ops.shape(sequence_input)]) + seq_type_check = check_ops.assert_type( + sequence_input, + dtypes.float32, + message='sequence_input must have dtype float32; got {}.'.format( + sequence_input.dtype)) + ctx_rank_check = check_ops.assert_rank( + context_input, + 2, + message='context_input must have rank 2', + data=[array_ops.shape(context_input)]) + ctx_type_check = check_ops.assert_type( + context_input, + dtypes.float32, + message='context_input must have dtype float32; got {}.'.format( + context_input.dtype)) + with ops.control_dependencies( + [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): + padded_length = array_ops.shape(sequence_input)[1] + tiled_context_input = array_ops.tile( + array_ops.expand_dims(context_input, 1), + array_ops.concat([[1], [padded_length], [1]], 0)) + return array_ops.concat([sequence_input, tiled_context_input], 2) + + +def _select_last_activations(activations, sequence_lengths): + """Selects the nth set of activations for each n in `sequence_length`. + + Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not + `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If + `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. + + Args: + activations: A `Tensor` with shape `[batch_size, padded_length, k]`. + sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`. + Returns: + A `Tensor` of shape `[batch_size, k]`. + """ + with ops.name_scope( + 'select_last_activations', values=[activations, sequence_lengths]): + activations_shape = array_ops.shape(activations) + batch_size = activations_shape[0] + padded_length = activations_shape[1] + output_units = activations_shape[2] + if sequence_lengths is None: + sequence_lengths = padded_length + start_indices = math_ops.to_int64( + math_ops.range(batch_size) * padded_length) + last_indices = start_indices + sequence_lengths - 1 + reshaped_activations = array_ops.reshape( + activations, [batch_size * padded_length, output_units]) + + last_activations = array_ops.gather(reshaped_activations, last_indices) + last_activations.set_shape([activations.shape[0], activations.shape[2]]) + return last_activations + + +def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, + context_feature_columns, input_layer_partitioner): + """Function builder for a rnn logit_fn. + + Args: + output_units: An int indicating the dimension of the logit layer. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell`. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. + context_feature_columns: An iterable containing the `FeatureColumn`s + that represent contextual input. + input_layer_partitioner: Partitioner for input layer. + + Returns: + A logit_fn (see below). + + Raises: + ValueError: If output_units is not an int. + """ + if not isinstance(output_units, int): + raise ValueError('output_units must be an int. Given type: {}'.format( + type(output_units))) + + def rnn_logit_fn(features, mode): + """Recurrent Neural Network logit_fn. + + Args: + features: This is the first item returned from the `input_fn` + passed to `train`, `evaluate`, and `predict`. This should be a + single `Tensor` or `dict` of same. + mode: Optional. Specifies if this training, evaluation or prediction. See + `ModeKeys`. + + Returns: + A `Tensor` representing the logits. + """ + with variable_scope.variable_scope( + 'sequence_input_layer', + values=tuple(six.itervalues(features)), + partitioner=input_layer_partitioner): + sequence_input, sequence_length = seq_fc.sequence_input_layer( + features=features, feature_columns=sequence_feature_columns) + summary.histogram('sequence_length', sequence_length) + + if context_feature_columns: + context_input = feature_column_lib.input_layer( + features=features, + feature_columns=context_feature_columns) + sequence_input = _concatenate_context_input(sequence_input, + context_input) + + cell = rnn_cell_fn(mode) + # Ignore output state. + rnn_outputs, _ = rnn.dynamic_rnn( + cell=cell, + inputs=sequence_input, + dtype=dtypes.float32, + time_major=False) + last_activations = _select_last_activations(rnn_outputs, sequence_length) + + with variable_scope.variable_scope('logits', values=(rnn_outputs,)): + logits = core_layers.dense( + last_activations, + units=output_units, + activation=None, + kernel_initializer=init_ops.glorot_uniform_initializer()) + return logits + + return rnn_logit_fn + + +def _rnn_model_fn(features, + labels, + mode, + head, + rnn_cell_fn, + sequence_feature_columns, + context_feature_columns, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Recurrent Neural Net model_fn. + + Args: + features: dict of `Tensor` and `SparseTensor` objects returned from + `input_fn`. + labels: `Tensor` of shape [batch_size, 1] or [batch_size] with labels. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + head: A `head_lib._Head` instance. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell`. + sequence_feature_columns: Iterable containing `FeatureColumn`s that + represent sequential model inputs. + context_feature_columns: Iterable containing `FeatureColumn`s that + represent model inputs not associated with a specific timestep. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use the Adagrad + optimizer with a default learning rate of 0.05 and gradient clip norm of + 5.0. + input_layer_partitioner: Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Returns: + An `EstimatorSpec` instance. + + Raises: + ValueError: If mode or optimizer is invalid, or features has the wrong type. + """ + if not isinstance(features, dict): + raise ValueError('features should be a dictionary of `Tensor`s. ' + 'Given type: {}'.format(type(features))) + + # If user does not provide an optimizer instance, use the optimizer specified + # by the string with default learning rate and gradient clipping. + if not isinstance(optimizer, optimizer_lib.Optimizer): + optimizer = optimizers.get_optimizer_instance( + optimizer, learning_rate=_DEFAULT_LEARNING_RATE) + optimizer = extenders.clip_gradients_by_norm(optimizer, _DEFAULT_CLIP_NORM) + + num_ps_replicas = config.num_ps_replicas if config else 0 + partitioner = partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas) + with variable_scope.variable_scope( + 'rnn', + values=tuple(six.itervalues(features)), + partitioner=partitioner): + input_layer_partitioner = input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas, + min_slice_size=64 << 20)) + + logit_fn = _rnn_logit_fn_builder( + output_units=head.logits_dimension, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + input_layer_partitioner=input_layer_partitioner) + logits = logit_fn(features=features, mode=mode) + + def _train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizer.minimize( + loss, + global_step=training_util.get_global_step()) + + return head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + train_op_fn=_train_op_fn, + logits=logits) + + +class RNNClassifier(estimator.Estimator): + """A classifier for TensorFlow RNN models. + + Trains a recurrent neural network model to classify instances into one of + multiple classes. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNClassifier( + num_units=[32, 16], cell_type='lstm', + sequence_feature_columns=[token_emb]) + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. + * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss is calculated by using softmax cross entropy. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + n_classes=2, + weight_column=None, + label_vocabulary=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + n_classes: Number of label classes. Defaults to 2, namely binary + classification. Must be > 1. + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. If it is a string, it is + used as a key to fetch weight tensor from the `features`. If it is a + `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, + then weight_column.normalizer_fn is applied on it to get weight tensor. + label_vocabulary: A list of strings represents possible label values. If + given, labels must be string type and have any value in + `label_vocabulary`. If it is not given, that means labels are + already encoded as integer or float within [0, 1] for `n_classes=2` and + encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . + Also there will be errors if vocabulary is not provided and labels are + string. + optimizer: An instance of `tf.Optimizer` used to train the model. Defaults + to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. + """ + if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): + raise ValueError( + 'num_units and cell_type must not be specified when using rnn_cell_fn' + ) + if not rnn_cell_fn: + if cell_type == USE_DEFAULT: + cell_type = 'basic_rnn' + rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + + if n_classes == 2: + head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access + weight_column=weight_column, + label_vocabulary=label_vocabulary) + else: + head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access + n_classes, weight_column=weight_column, + label_vocabulary=label_vocabulary) + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNClassifier, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..393f94f5c7de02c56d93993bbeb8aaec4ea8234c --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -0,0 +1,1131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rnn.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import rnn +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import monitored_session +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_util + + +# Names of variables created by BasicRNNCell model. +TOKEN_EMBEDDING_NAME = 'rnn/sequence_input_layer/input_layer/tokens_sequential_embedding/embedding_weights' +CELL_WEIGHTS_NAME = 'rnn/rnn/basic_rnn_cell/kernel' +CELL_BIAS_NAME = 'rnn/rnn/basic_rnn_cell/bias' +MULTI_CELL_WEIGHTS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/kernel' +MULTI_CELL_BIAS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/bias' +LOGITS_WEIGHTS_NAME = 'rnn/logits/dense/kernel' +LOGITS_BIAS_NAME = 'rnn/logits/dense/bias' + + +def _assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def create_checkpoint(rnn_weights, rnn_biases, logits_weights, logits_biases, + global_step, model_dir): + """Create checkpoint file with provided model weights. + + Args: + rnn_weights: Iterable of values of weights for the RNN cell. + rnn_biases: Iterable of values of biases for the RNN cell. + logits_weights: Iterable of values for matrix connecting RNN output to + logits. + logits_biases: Iterable of values for logits bias term. + global_step: Initial global step to save in checkpoint. + model_dir: Directory into which checkpoint is saved. + """ + model_weights = {} + model_weights[CELL_WEIGHTS_NAME] = rnn_weights + model_weights[CELL_BIAS_NAME] = rnn_biases + model_weights[LOGITS_WEIGHTS_NAME] = logits_weights + model_weights[LOGITS_BIAS_NAME] = logits_biases + + with ops.Graph().as_default(): + # Create model variables. + for k, v in six.iteritems(model_weights): + variables_lib.Variable(v, name=k, dtype=dtypes.float32) + + # Create non-model variables. + global_step_var = training_util.create_global_step() + assign_op = global_step_var.assign(global_step) + + # Initialize vars and save checkpoint. + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=model_dir) as sess: + sess.run(assign_op) + + +class RNNLogitFnTest(test.TestCase): + """Tests correctness of logits calculated from _rnn_logit_fn_builder.""" + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_logits(self, mode, rnn_units, logits_dimension, features_fn, + sequence_feature_columns, context_feature_columns, + expected_logits): + """Tests that the expected logits are calculated.""" + with ops.Graph().as_default(): + # Global step needed for MonitoredSession, which is in turn used to + # explicitly set variable weights through a checkpoint. + training_util.create_global_step() + # Use a variable scope here with 'rnn', emulating the rnn model_fn, so + # the checkpoint naming is shared. + with variable_scope.variable_scope('rnn'): + input_layer_partitioner = ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=0, min_slice_size=64 << 20)) + logit_fn = rnn._rnn_logit_fn_builder( + output_units=logits_dimension, + rnn_cell_fn=rnn._make_rnn_cell_fn(rnn_units), + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + input_layer_partitioner=input_layer_partitioner) + # Features are constructed within this function, otherwise the Tensors + # containing the features would be defined outside this graph. + logits = logit_fn(features=features_fn(), mode=mode) + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=self._model_dir) as sess: + self.assertAllClose(expected_logits, sess.run(logits), atol=1e-4) + + def testOneDimLogits(self): + """Tests one-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10]], [[5]]] + initial_state = [0, 0] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)]] + = [[0.53, -0.37]] + logits = [[-1*0.53 - 1*0.37 + 0.3]] = [[-0.6033]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033]]) + + def testMultiDimLogits(self): + """Tests multi-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10]], [[5]]] + initial_state = [0, 0] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)]] + = [[0.53, -0.37]] + logits = [[-1*0.53 - 1*0.37 + 0.3], + [0.5*0.53 + 0.3*0.37 + 0.4], + [0.2*0.53 - 0.1*0.37 + 0.5] + = [[-0.6033, 0.7777, 0.5698]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=3, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033, 0.7777, 0.5698]]) + + def testMultiExampleMultiDim(self): + """Tests multiple examples and multi-dimensional logits. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10], [5]], [[2], [7]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + .2*0 + .3*0 +.2), + tanh(-.2*2 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91], [0.38, 0.10]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)], + [tanh(.1*7 + .2*.38 + .3*.10 +.2), + tanh(-.2*7 - .3*.38 - .4*.10 +.5)]] + = [[0.53, -0.37], [0.76, -0.78] + logits = [[-1*0.53 - 1*0.37 + 0.3, + 0.5*0.53 + 0.3*0.37 + 0.4, + 0.2*0.53 - 0.1*0.37 + 0.5], + [-1*0.76 - 1*0.78 + 0.3, + 0.5*0.76 +0.3*0.78 + 0.4, + 0.2*0.76 -0.1*0.78 + 0.5]] + = [[-0.6033, 0.7777, 0.5698], [-1.2473, 1.0170, 0.5745]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,)) + ] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=3, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033, 0.7777, 0.5698], + [-1.2473, 1.0170, 0.5745]]) + + def testMultiExamplesDifferentLength(self): + """Tests multiple examples with different lengths. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10], [5]], [[2], [0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2), + tanh(-.2*10 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + .2*0 + .3*0 +.2), + tanh(-.2*2 - .3*0 - .4*0 +.5)]] + = [[0.83, -0.91], [0.38, 0.10]] + rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2), + tanh(-.2*5 - .3*.83 + .4*.91 +.5)], + []] + = [[0.53, -0.37], []] + logits = [[-1*0.53 - 1*0.37 + 0.3], + [-1*0.38 + 1*0.10 + 0.3]] + = [[-0.6033], [0.0197]] + """ + base_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.6033], [0.0197]]) + + def testMultiExamplesWithContext(self): + """Tests multiple examples with context features. + + Intermediate values are rounded for ease in reading. + input_layer = [[[10, -0.5], [5, -0.5]], [[2, 0.8], [0, 0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.1*10 - 1*.5 + .2*0 + .3*0 +.2), + tanh(-.2*10 - 0.9*.5 - .3*0 - .4*0 +.5)], + [tanh(.1*2 + 1*.8 + .2*0 + .3*0 +.2), + tanh(-.2*2 + .9*.8 - .3*0 - .4*0 +.5)]] + = [[0.60, -0.96], [0.83, 0.68]] + rnn_output_timestep_2 = [[tanh(.1*5 - 1*.5 + .2*.60 - .3*.96 +.2), + tanh(-.2*5 - .9*.5 - .3*.60 + .4*.96 +.5)], + []] + = [[0.03, -0.63], []] + logits = [[-1*0.03 - 1*0.63 + 0.3], + [-1*0.83 + 1*0.68 + 0.3]] + = [[-0.3662], [0.1414]] + """ + base_global_step = 100 + create_checkpoint( + # Context features weights are inserted between input and state weights. + rnn_weights=[[.1, -.2], [1., 0.9], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + 'context': [[-0.5], [0.8]], + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + context_feature_columns = [fc.numeric_column('context', shape=(1,))] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-0.3662], [0.1414]]) + + def testMultiExamplesMultiFeatures(self): + """Tests examples with multiple sequential feature columns. + + Intermediate values are rounded for ease in reading. + input_layer = [[[1, 0, 10], [0, 1, 5]], [[1, 0, 2], [0, 0, 0]]] + initial_state = [[0, 0], [0, 0]] + rnn_output_timestep_1 = [[tanh(.5*1 + 1*0 + .1*10 + .2*0 + .3*0 +.2), + tanh(-.5*1 - 1*0 - .2*10 - .3*0 - .4*0 +.5)], + [tanh(.5*1 + 1*0 + .1*2 + .2*0 + .3*0 +.2), + tanh(-.5*1 - 1*0 - .2*2 - .3*0 - .4*0 +.5)]] + = [[0.94, -0.96], [0.72, -0.38]] + rnn_output_timestep_2 = [[tanh(.5*0 + 1*1 + .1*5 + .2*.94 - .3*.96 +.2), + tanh(-.5*0 - 1*1 - .2*5 - .3*.94 + .4*.96 +.5)], + []] + = [[0.92, -0.88], []] + logits = [[-1*0.92 - 1*0.88 + 0.3], + [-1*0.72 - 1*0.38 + 0.3]] + = [[-1.5056], [-0.7962]] + """ + base_global_step = 100 + create_checkpoint( + # FeatureColumns are sorted alphabetically, so on_sale weights are + # inserted before price. + rnn_weights=[[.5, -.5], [1., -1.], [.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=base_global_step, + model_dir=self._model_dir) + + def features_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + 'on_sale': + sparse_tensor.SparseTensor( + values=[0, 1, 0], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + } + + price_column = seq_fc.sequence_numeric_column('price', shape=(1,)) + on_sale_column = fc.indicator_column( + seq_fc.sequence_categorical_column_with_identity( + 'on_sale', num_buckets=2)) + sequence_feature_columns = [price_column, on_sale_column] + context_feature_columns = [] + + for mode in [ + model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.PREDICT + ]: + self._test_logits( + mode, + rnn_units=[2], + logits_dimension=1, + features_fn=features_fn, + sequence_feature_columns=sequence_feature_columns, + context_feature_columns=context_feature_columns, + expected_logits=[[-1.5056], [-0.7962]]) + + +class RNNClassifierTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _assert_checkpoint( + self, n_classes, input_units, cell_units, expected_global_step): + + shapes = { + name: shape for (name, shape) in + checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual( + expected_global_step, + checkpoint_utils.load_variable( + self._model_dir, ops.GraphKeys.GLOBAL_STEP)) + + # RNN Cell variables. + if len(cell_units) > 1: + for i, cell_unit in enumerate(cell_units): + self.assertEqual([input_units + cell_unit, cell_unit], + shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i]) + self.assertEqual([cell_unit], + shapes[MULTI_CELL_BIAS_NAME_PATTERN % i]) + input_units = cell_unit + elif len(cell_units) == 1: + self.assertEqual([input_units + cell_unit, cell_unit], + shapes[CELL_WEIGHTS_NAME]) + self.assertEqual([cell_unit], shapes[CELL_BIAS_NAME]) + + # Logits variables. + logits_dimension = n_classes if n_classes > 2 else 1 + self.assertEqual([cell_units[-1], logits_dimension], + shapes[LOGITS_WEIGHTS_NAME]) + self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME]) + + def _mock_optimizer(self, expected_loss=None): + expected_var_names = [ + '%s/part_0:0' % CELL_BIAS_NAME, + '%s/part_0:0' % CELL_WEIGHTS_NAME, + '%s/part_0:0' % LOGITS_BIAS_NAME, + '%s/part_0:0' % LOGITS_WEIGHTS_NAME, + ] + + def _minimize(loss, global_step): + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual( + expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + return state_ops.assign_add(global_step, 1).op + assert_loss = _assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + return state_ops.assign_add(global_step, 1).op + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def testConflictingRNNCellFn(self): + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + cell_units = [4, 2] + + with self.assertRaisesRegexp( + ValueError, + 'num_units and cell_type must not be specified when using rnn_cell_fn'): + rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=lambda x: x, + num_units=cell_units) + + with self.assertRaisesRegexp( + ValueError, + 'num_units and cell_type must not be specified when using rnn_cell_fn'): + rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=lambda x: x, + cell_type='lstm') + + def _testFromScratchWithDefaultOptimizer(self, n_classes): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat'], + indices=[[0, 0], [0, 1], [0, 2]], + dense_shape=[1, 3]), + }, [[1]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + + cell_units = [4, 2] + est = rnn.RNNClassifier( + sequence_feature_columns=[embed], + num_units=cell_units, + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def testBinaryClassFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=2) + + def testMultiClassFromScratchWithDefaultOptimizer(self): + self._testFromScratchWithDefaultOptimizer(n_classes=4) + + def testFromScratchWithCustomRNNCellFn(self): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat'], + indices=[[0, 0], [0, 1], [0, 2]], + dense_shape=[1, 3]), + }, [[1]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + cell_units = [4, 2] + n_classes = 2 + + def rnn_cell_fn(mode): + del mode # unused + cells = [rnn_cell.BasicRNNCell(num_units=n) for n in cell_units] + return rnn_cell.MultiRNNCell(cells) + + est = rnn.RNNClassifier( + sequence_feature_columns=[embed], + rnn_cell_fn=rnn_cell_fn, + n_classes=n_classes, + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def _testExampleWeight(self, n_classes): + def train_input_fn(): + return { + 'tokens': + sparse_tensor.SparseTensor( + values=['the', 'cat', 'sat', 'dog', 'barked'], + indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + dense_shape=[2, 3]), + 'w': [[1], [2]], + }, [[1], [0]] + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + input_units = 2 + + cell_units = [4, 2] + est = rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=[embed], + n_classes=n_classes, + weight_column='w', + model_dir=self._model_dir) + + # Train for a few steps, and validate final checkpoint. + num_steps = 10 + est.train(input_fn=train_input_fn, steps=num_steps) + self._assert_checkpoint(n_classes, input_units, cell_units, num_steps) + + def testBinaryClassWithExampleWeight(self): + self._testExampleWeight(n_classes=2) + + def testMultiClassWithExampleWeight(self): + self._testExampleWeight(n_classes=4) + + def testBinaryClassFromCheckpoint(self): + initial_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=initial_global_step, + model_dir=self._model_dir) + + def train_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + }, [[0], [1]] + + # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics. + # See that test for loss calculation. + mock_optimizer = self._mock_optimizer(expected_loss=1.119661) + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + est.train(input_fn=train_input_fn, steps=10) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + def testMultiClassFromCheckpoint(self): + initial_global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=initial_global_step, + model_dir=self._model_dir) + + def train_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }, [[0], [1]] + + # Uses same checkpoint and examples as testMultiClassEvaluationMetrics. + # See that test for loss calculation. + mock_optimizer = self._mock_optimizer(expected_loss=2.662932) + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + optimizer=mock_optimizer, + model_dir=self._model_dir) + self.assertEqual(0, mock_optimizer.minimize.call_count) + est.train(input_fn=train_input_fn, steps=10) + self.assertEqual(1, mock_optimizer.minimize.call_count) + + +def sorted_key_dict(unsorted_dict): + return {k: unsorted_dict[k] for k in sorted(unsorted_dict)} + + +class RNNClassifierEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def testBinaryClassEvaluationMetrics(self): + global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=global_step, + model_dir=self._model_dir) + + def eval_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2.], + indices=[[0, 0], [0, 1], [1, 0]], + dense_shape=[2, 2]), + }, [[0], [1]] + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + model_dir=self._model_dir) + eval_metrics = est.evaluate(eval_input_fn, steps=1) + + # Uses identical numbers to testMultiExamplesWithDifferentLength. + # See that test for logits calculation. + # logits = [[-0.603282], [0.019719]] + # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]] + # loss = -label * ln(p) - (1 - label) * ln(1 - p) + # = [[0.436326], [0.683335]] + expected_metrics = { + ops.GraphKeys.GLOBAL_STEP: global_step, + metric_keys.MetricKeys.LOSS: 1.119661, + metric_keys.MetricKeys.LOSS_MEAN: 0.559831, + metric_keys.MetricKeys.ACCURACY: 1.0, + metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262, + metric_keys.MetricKeys.LABEL_MEAN: 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + # With default threshold of 0.5, the model is a perfect classifier. + metric_keys.MetricKeys.RECALL: 1.0, + metric_keys.MetricKeys.PRECISION: 1.0, + # Positive example is scored above negative, so AUC = 1.0. + metric_keys.MetricKeys.AUC: 1.0, + metric_keys.MetricKeys.AUC_PR: 1.0, + } + self.assertAllClose( + sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) + + def testMultiClassEvaluationMetrics(self): + global_step = 100 + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=global_step, + model_dir=self._model_dir) + + def eval_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5., 2., 7.], + indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + dense_shape=[2, 2]), + }, [[0], [1]] + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + model_dir=self._model_dir) + eval_metrics = est.evaluate(eval_input_fn, steps=1) + + # Uses identical numbers to testMultiExampleMultiDim. + # See that test for logits calculation. + # logits = [[-0.603282, 0.777708, 0.569756], + # [-1.247356, 1.017018, 0.574481]] + # logits_exp = exp(logits) / (1 + exp(logits)) + # = [[0.547013, 2.176468, 1.767836], + # [0.287263, 2.764937, 1.776208]] + # softmax_probabilities = logits_exp / logits_exp.sum() + # = [[0.121793, 0.484596, 0.393611], + # [0.059494, 0.572639, 0.367866]] + # loss = -1. * log(softmax[label]) + # = [[2.105432], [0.557500]] + expected_metrics = { + ops.GraphKeys.GLOBAL_STEP: global_step, + metric_keys.MetricKeys.LOSS: 2.662932, + metric_keys.MetricKeys.LOSS_MEAN: 1.331466, + metric_keys.MetricKeys.ACCURACY: 0.5, + } + + self.assertAllClose( + sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) + + +class RNNClassifierPredictionTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def testBinaryClassPredictions(self): + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1.], [1.]], + logits_biases=[0.3], + global_step=0, + model_dir=self._model_dir) + + def predict_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + label_vocabulary = ['class_0', 'class_1'] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=2, + label_vocabulary=label_vocabulary, + model_dir=self._model_dir) + # Uses identical numbers to testOneDimLogits. + # See that test for logits calculation. + # logits = [-0.603282] + # logistic = exp(-0.6033) / (1 + exp(-0.6033)) = [0.353593] + # probabilities = [0.646407, 0.353593] + # class_ids = argmax(probabilities) = [0] + predictions = next(est.predict(predict_input_fn)) + self.assertAllClose([-0.603282], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose([0.353593], + predictions[prediction_keys.PredictionKeys.LOGISTIC]) + self.assertAllClose( + [0.646407, 0.353593], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([0], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertEqual([b'class_0'], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + def testMultiClassPredictions(self): + create_checkpoint( + rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]], + rnn_biases=[.2, .5], + logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]], + logits_biases=[0.3, 0.4, 0.5], + global_step=0, + model_dir=self._model_dir) + + def predict_input_fn(): + return { + 'price': + sparse_tensor.SparseTensor( + values=[10., 5.], + indices=[[0, 0], [0, 1]], + dense_shape=[1, 2]), + } + + sequence_feature_columns = [ + seq_fc.sequence_numeric_column('price', shape=(1,))] + label_vocabulary = ['class_0', 'class_1', 'class_2'] + + est = rnn.RNNClassifier( + num_units=[2], + sequence_feature_columns=sequence_feature_columns, + n_classes=3, + label_vocabulary=label_vocabulary, + model_dir=self._model_dir) + # Uses identical numbers to testMultiDimLogits. + # See that test for logits calculation. + # logits = [-0.603282, 0.777708, 0.569756] + # logits_exp = exp(logits) = [0.547013, 2.176468, 1.767836] + # softmax_probabilities = logits_exp / logits_exp.sum() + # = [0.121793, 0.484596, 0.393611] + # class_ids = argmax(probabilities) = [1] + predictions = next(est.predict(predict_input_fn)) + self.assertAllClose([-0.603282, 0.777708, 0.569756], + predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose( + [0.121793, 0.484596, 0.393611], + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose([1], + predictions[prediction_keys.PredictionKeys.CLASS_IDS]) + self.assertEqual([b'class_1'], + predictions[prediction_keys.PredictionKeys.CLASSES]) + + +class RNNClassifierIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, n_classes, + batch_size): + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + + cell_units = [4, 2] + est = rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=feature_columns, + n_classes=n_classes, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUATE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + # EXPORT + feature_spec = { + 'tokens': parsing_ops.VarLenFeature(dtypes.string), + 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), + } + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def testNumpyInputFn(self): + """Tests complete flow with numpy_input_fn.""" + n_classes = 3 + batch_size = 10 + words = ['dog', 'cat', 'bird', 'the', 'a', 'sat', 'flew', 'slept'] + # Numpy only supports dense input, so all examples will have same length. + # TODO(b/73160931): Update test when support for prepadded data exists. + sequence_length = 3 + + features = [] + for _ in range(batch_size): + sentence = random.sample(words, sequence_length) + features.append(sentence) + + x_data = np.array(features) + y_data = np.random.randint(n_classes, size=batch_size) + + train_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'tokens': x_data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + n_classes=n_classes, + batch_size=batch_size) + + def testParseExampleInputFn(self): + """Tests complete flow with input_fn constructed from parse_example.""" + n_classes = 3 + batch_size = 10 + words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept'] + + serialized_examples = [] + for _ in range(batch_size): + sequence_length = random.randint(1, len(words)) + sentence = random.sample(words, sequence_length) + label = random.randint(0, n_classes - 1) + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'tokens': + feature_pb2.Feature(bytes_list=feature_pb2.BytesList( + value=sentence)), + 'label': + feature_pb2.Feature(int64_list=feature_pb2.Int64List( + value=[label])), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'tokens': parsing_ops.VarLenFeature(dtypes.string), + 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), + } + def _train_input_fn(): + features = parsing_ops.parse_example(serialized_examples, feature_spec) + labels = features.pop('label') + return features, labels + def _eval_input_fn(): + features = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + labels = features.pop('label') + return features, labels + def _predict_input_fn(): + features = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + features.pop('label') + return features, None + + self._test_complete_flow( + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + n_classes=n_classes, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 180f1b68f3b56113dfbbfc100bd04efc3bb8b31f..0a648d5d40e431bedb42017b15cabe078ac22fa7 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -66,6 +66,7 @@ tf_custom_op_py_library( "//tensorflow/python:variables", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -223,7 +224,10 @@ py_test( srcs = ["python/ops/kmeans_test.py"], shard_count = 4, srcs_version = "PY2AND3", - tags = ["notsan"], # b/67512932 + tags = [ + "nomac", # b/73741358 + "notsan", # b/67512932 + ], deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", @@ -238,6 +242,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:training", "//tensorflow/python/estimator:run_config", + "//tensorflow/python/feature_column:feature_column_py", "//third_party/py/numpy", ], ) @@ -342,16 +347,3 @@ cuda_py_test( ], main = "python/kernel_tests/masked_matmul_benchmark.py", ) - -# All files -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD index bbe842bd5ccc7357805adda1df42ba8799fcd8f2..363baa121ab3854a802ca3606e35597d31b35a57 100644 --- a/tensorflow/contrib/factorization/examples/BUILD +++ b/tensorflow/contrib/factorization/examples/BUILD @@ -21,14 +21,3 @@ tf_py_test( ], tags = ["notsan"], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD index 44eab56011dad2f6fbe843b3569b4acc5c5e542a..ea8b9a17a27093cb57564861815edd6ecb18a014 100644 --- a/tensorflow/contrib/factorization/kernels/BUILD +++ b/tensorflow/contrib/factorization/kernels/BUILD @@ -67,14 +67,3 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index dd61f59585aee2e0245cfd6797b313b972c19bc5..2a6c97e8b9526894eba057505a2bf823ad778f56 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -353,7 +353,7 @@ class NearestNeighborsOp : public OpKernel { auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); const int64 num_threads = worker_threads.num_threads; // This kernel might be configured to use fewer than the total number of - // available CPUs on the host machine. To avoid descructive interference + // available CPUs on the host machine. To avoid destructive interference // with other jobs running on the host machine, we must only use a fraction // of total available L3 cache. Unfortunately, we cannot query the host // machine to get the number of physical CPUs. So, we use a fixed per-CPU diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 23137e0a973c0bdd2cdbd97159f7fd310178bf54..84e80791f4991ad2b67d0a00ee1e00cf0d0daadc 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -41,11 +41,12 @@ from tensorflow.python.platform import resource_loader _clustering_ops = loader.load_op_library( resource_loader.get_path_to_datafile('_clustering_ops.so')) -# Euclidean distance between vectors U and V is defined as ||U - V||_F which is -# the square root of the sum of the absolute squares of the elements difference. +# Euclidean distance between vectors U and V is defined as \\(||U - V||_F\\) +# which is the square root of the sum of the absolute squares of the elements +# difference. SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean' # Cosine distance between vectors U and V is defined as -# 1 - (U \dot V) / (||U||_F ||V||_F) +# \\(1 - (U \dot V) / (||U||_F ||V||_F)\\) COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' @@ -472,8 +473,8 @@ class KMeans(object): # Locally compute the sum of inputs mapped to each id. # For a cluster with old cluster value x, old count n, and with data # d_1,...d_k newly assigned to it, we recompute the new value as - # x += (sum_i(d_i) - k * x) / (n + k). - # Compute sum_i(d_i), see comment above. + # \\(x += (sum_i(d_i) - k * x) / (n + k)\\). + # Compute \\(sum_i(d_i)\\), see comment above. cluster_center_updates = math_ops.unsorted_segment_sum( inp, unique_idx, num_unique_cluster_idx) # Shape to enable broadcasting count_updates and learning_rate to inp. diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 054888e734086c153f7af59f4548d4d20abab813..811fa89bc38c61b16710a441b99d9e5dfac67668 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -51,9 +51,9 @@ class WALSModel(object): r"""A model for Weighted Alternating Least Squares matrix factorization. It minimizes the following loss function over U, V: - \\( - \|\sqrt W \odot (A - U V^T) \|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2) - )\\ + $$ + \|\sqrt W \odot (A - U V^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2) + $$ where, A: input matrix, W: weight matrix. Note that the (element-wise) square root of the weights @@ -61,12 +61,12 @@ class WALSModel(object): U, V: row_factors and column_factors matrices, \\(\lambda)\\: regularization. Also we assume that W is of the following special form: - \\( W_{ij} = W_0 + R_i * C_j )\\ if \\(A_{ij} \ne 0)\\, - \\(W_{ij} = W_0)\\ otherwise. + \\( W_{ij} = W_0 + R_i * C_j \\) if \\(A_{ij} \ne 0\\), + \\(W_{ij} = W_0\\) otherwise. where, - \\(W_0)\\: unobserved_weight, - \\(R_i)\\: row_weights, - \\(C_j)\\: col_weights. + \\(W_0\\): unobserved_weight, + \\(R_i\\): row_weights, + \\(C_j\\): col_weights. Note that the current implementation supports two operation modes: The default mode is for the condition where row_factors and col_factors can individually @@ -82,14 +82,15 @@ class WALSModel(object): normalized as follows: _, _, unregularized_loss, regularization, sum_weights = update_row_factors(sp_input) - if sp_input contains the rows {A_i, i \in I}, and the input matrix A has n - total rows, then the minibatch loss = unregularized_loss + regularization is - \\( + if sp_input contains the rows \\({A_i, i \in I}\\), and the input matrix A + has n total rows, then the minibatch loss = unregularized_loss + + regularization is + $$ (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2 - )\\ + $$ The sum_weights tensor contains the normalized sum of weights - sum(W_I) * n / |I|. + \\(sum(W_I) * n / |I|\\). A typical usage example (pseudocode): @@ -106,7 +107,7 @@ class WALSModel(object): # the prep_gramian_op for row(column) can be run. worker_init_op = model.worker_init - # To be run once per interation sweep before the row(column) update + # To be run once per integration sweep before the row(column) update # initialize ops can be run. Note that in the distributed training # situations, this should only be run by the chief trainer. All other # trainers need to block until this is done. @@ -118,9 +119,9 @@ class WALSModel(object): init_row_update_op = model.initialize_row_update_op init_col_update_op = model.initialize_col_update_op - # Ops to upate row(column). This can either take the entire sparse tensor - # or slices of sparse tensor. For distributed trainer, each trainer - # handles just part of the matrix. + # Ops to update row(column). This can either take the entire sparse + # tensor or slices of sparse tensor. For distributed trainer, each + # trainer handles just part of the matrix. _, row_update_op, unreg_row_loss, row_reg, _ = model.update_row_factors( sp_input=matrix_slices_from_queue_for_worker_shard) row_loss = unreg_row_loss + row_reg @@ -220,10 +221,10 @@ class WALSModel(object): in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number of inner lists matching the number of row factor shards and the elements in each inner list are the weights for the rows of the corresponding row - factor shard. In this case, w_ij = unonbserved_weight + + factor shard. In this case, w_ij = unobserved_weight + row_weights[i] * col_weights[j]. - If this is a single non-negative real number, this value is used for - all row weights and w_ij = unobserved_weight + row_weights * + all row weights and \\(w_ij\\) = unobserved_weight + row_weights * col_weights[j]. Note that it is allowed to have row_weights as a list while col_weights a single number or vice versa. @@ -435,7 +436,7 @@ class WALSModel(object): gramian: Variable storing the gramian calculated from the factors. Returns: - A op that updates the gramian with the calcuated value from the factors. + A op that updates the gramian with the calculated value from the factors. """ partial_gramians = [] for f in factors: @@ -564,7 +565,7 @@ class WALSModel(object): Note that specifically this initializes the cache of the row and column weights on workers when `use_factors_weights_cache` is True. In this case, - if these weights are being calcualted and reset after the object is created, + if these weights are being calculated and reset after the object is created, it is important to ensure this ops is run afterwards so the cache reflects the correct values. """ @@ -665,18 +666,18 @@ class WALSModel(object): factors. unregularized_loss: A tensor (scalar) that contains the normalized minibatch loss corresponding to sp_input, without the regularization - term. If sp_input contains the rows {A_{i, :}, i \in I}, and the input - matrix A has n total rows, then the unregularized loss is: - (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I| + term. If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the + input matrix A has n total rows, then the unregularized loss is: + \\(\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I|\\) The total loss is unregularized_loss + regularization. regularization: A tensor (scalar) that contains the normalized regularization term for the minibatch loss corresponding to sp_input. - If sp_input contains the rows {A_{i, :}, i \in I}, and the input matrix - A has n total rows, then the regularization term is: - \lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2. + If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the input + matrix A has n total rows, then the regularization term is: + \\(\lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2\\). sum_weights: The sum of the weights W_I corresponding to sp_input, - normalized by a factor of n / |I|. The root weighted squared error is: - \sqrt(unregularized_loss / sum_weights). + normalized by a factor of \\(n / |I|\\). The root weighted squared + error is: \sqrt(unregularized_loss / sum_weights). """ return self._process_input_helper( True, sp_input=sp_input, transpose_input=transpose_input) @@ -698,18 +699,18 @@ class WALSModel(object): factors. unregularized_loss: A tensor (scalar) that contains the normalized minibatch loss corresponding to sp_input, without the regularization - term. If sp_input contains the columns {A_{:, j}, j \in J}, and the - input matrix A has m total columns, then the unregularized loss is: - (\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I| + term. If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and + the input matrix A has m total columns, then the unregularized loss is: + \\(\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I|\\) The total loss is unregularized_loss + regularization. regularization: A tensor (scalar) that contains the normalized regularization term for the minibatch loss corresponding to sp_input. - If sp_input contains the columns {A_{:, j}, j \in J}, and the input - matrix A has m total columns, then the regularization term is: - \lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2. + If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and the + input matrix A has m total columns, then the regularization term is: + \\(\lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2\\). sum_weights: The sum of the weights W_J corresponding to sp_input, - normalized by a factor of m / |J|. The root weighted squared error is: - \sqrt(unregularized_loss / sum_weights). + normalized by a factor of \\(m / |J|\\). The root weighted squared + error is: \sqrt(unregularized_loss / sum_weights). """ return self._process_input_helper( False, sp_input=sp_input, transpose_input=transpose_input) @@ -720,8 +721,8 @@ class WALSModel(object): projection_weights=None): """Projects the row factors. - This computes the row embedding u_i for an observed row a_i by solving - one iteration of the update equations. + This computes the row embedding \\(u_i\\) for an observed row \\(a_i\\) by + solving one iteration of the update equations. Args: sp_input: A SparseTensor representing a set of rows. Please note that the @@ -753,8 +754,8 @@ class WALSModel(object): projection_weights=None): """Projects the column factors. - This computes the column embedding v_j for an observed column a_j by solving - one iteration of the update equations. + This computes the column embedding \\(v_j\\) for an observed column + \\(a_j\\) by solving one iteration of the update equations. Args: sp_input: A SparseTensor representing a set of columns. Please note that @@ -938,7 +939,7 @@ class WALSModel(object): loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input) if transpose_input else new_sp_input) # sp_approx is the low rank estimate of the input matrix, formed by - # computing the product for (i, j) in loss_sp_input.indices. + # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices. sp_approx_vals = gen_factorization_ops.masked_matmul( new_left_values, right, diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index c8137339155ef1da8ee53967eea84a550f12ecbc..bb5140aeb3bf0238ca7cb52067ea6328dd1736d5 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -210,7 +210,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -283,8 +283,8 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 3 column feature vectors. - # This is expected to reprodue the same column factors in the model as the - # weights and feature vectors are identical to that used in model + # This is expected to reproduce the same column factors in the model as + # the weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( sp_input=sp_feeder, @@ -385,7 +385,7 @@ class WalsModelTest(test.TestCase): # Test row projection. # Using the specified projection weights for the 2 row feature vectors. - # This is expected to reprodue the same row factors in the model as the + # This is expected to reproduce the same row factors in the model as the # weights and feature vectors are identical to that used in model # training. projected_rows = wals_model.project_row_factors( @@ -462,8 +462,8 @@ class WalsModelTest(test.TestCase): # Test column projection. # Using the specified projection weights for the 2 column feature vectors. - # This is expected to reprodue the same column factors in the model as the - # weights and feature vectors are identical to that used in model + # This is expected to reproduce the same column factors in the model as + # the weights and feature vectors are identical to that used in model # training. projected_cols = wals_model.project_col_factors( sp_input=sp_feeder, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 98d6434f4752b224201e38bed05ccd14428a758b..ccdd679d6ae197ab21d5fed5e89daa26bc15a809 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -54,10 +54,10 @@ def _covariance(x, diag): diagonal matrix just the diagonal is returned. """ num_points = math_ops.to_float(array_ops.shape(x)[0]) - x -= math_ops.reduce_mean(x, 0, keep_dims=True) + x -= math_ops.reduce_mean(x, 0, keepdims=True) if diag: cov = math_ops.reduce_sum( - math_ops.square(x), 0, keep_dims=True) / (num_points - 1) + math_ops.square(x), 0, keepdims=True) / (num_points - 1) else: cov = math_ops.matmul(x, x, transpose_a=True) / (num_points - 1) return cov @@ -280,7 +280,7 @@ class GmmAlgorithm(object): self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): - """Defines the full covariance probabilties per example in a class. + """Defines the full covariance probabilities per example in a class. Updates a matrix with dimension num_examples X num_classes. @@ -313,7 +313,7 @@ class GmmAlgorithm(object): # TODO(xavigonzalvo): look into alternatives to log for # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( - math_ops.log(self._covs + 1e-3), 1, keep_dims=True) + math_ops.log(self._covs + 1e-3), 1, keepdims=True) diff = shard - self._means x2 = math_ops.square(diff) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) @@ -344,21 +344,21 @@ class GmmAlgorithm(object): def _define_prior_log_prob_operation(self, shard_id): """Computes the prior probability of all samples. - Updates a vector where each item is the prior probabibility of an + Updates a vector where each item is the prior probability of an input example. Args: shard_id: id of current shard_id. """ self._prior_probs[shard_id] = math_ops.reduce_logsumexp( - self._probs[shard_id], axis=1, keep_dims=True) + self._probs[shard_id], axis=1, keepdims=True) def _define_expectation_operation(self, shard_id): # Shape broadcasting. probs = array_ops.expand_dims(self._probs[shard_id], 0) # Membership weights are computed as: - # w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)} - # {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)} + # $$w_{ik} = \frac{\alpha_k f(\mathbf{y_i}|\mathbf{\theta}_k)}$$ + # $$ {\sum_{m=1}^{K}\alpha_mf(\mathbf{y_i}|\mathbf{\theta}_m)}$$ # where "i" is the i-th example, "k" is the k-th mixture, theta are # the model parameters and y_i the observations. # These are defined for each shard. @@ -375,7 +375,7 @@ class GmmAlgorithm(object): """ # Soft assignment of each data point to each of the two clusters. self._points_in_k[shard_id] = math_ops.reduce_sum( - self._w[shard_id], 0, keep_dims=True) + self._w[shard_id], 0, keepdims=True) # Partial means. w_mul_x = array_ops.expand_dims( math_ops.matmul( @@ -454,7 +454,7 @@ class GmmAlgorithm(object): for shard_id, prior_probs in enumerate(self._prior_probs): op.append(prior_probs + math_ops.log(self._w[shard_id])) self._scores = array_ops.squeeze( - math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0) + math_ops.reduce_logsumexp(op, axis=2, keepdims=True), axis=0) def gmm(inp, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 00a4734eb6d89cd02484f1c5161366377cc71208..4fc9c96e9d0a317ef757d5e1bb6563ed7c8832af 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -210,7 +210,7 @@ class GMMTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index c861cfff544a78617aa1ace730b50c094cf16330..9ffdd3ba5e8ac496533d0207f2b6846dbc92bc89 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -26,6 +26,7 @@ from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.export import export_output +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -61,8 +62,8 @@ class _LossRelativeChangeHook(session_run_hook.SessionRunHook): loss = run_values.results assert loss is not None if self._prev_loss: - relative_change = (abs(loss - self._prev_loss) / - (1 + abs(self._prev_loss))) + relative_change = ( + abs(loss - self._prev_loss) / (1 + abs(self._prev_loss))) if relative_change < self._tolerance: run_context.request_stop() self._prev_loss = loss @@ -105,24 +106,32 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): logging.info(e) -def _parse_tensor_or_dict(features): +def _parse_features_if_necessary(features, feature_columns): """Helper function to convert the input points into a usable format. Args: - features: The input points. + features: The input features. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column instances + that can be passed to `tf.feature_column.input_layer`. If this is None, + all features will be used. Returns: - If `features` is a dict of `k` features, each of which is a vector of `n` - scalars, the return value is a Tensor of shape `(n, k)` representing `n` - input points, where the items in the `k` dimension are sorted - lexicographically by `features` key. If `features` is not a dict, it is - returned unmodified. + If `features` is a dict of `k` features (optionally filtered by + `feature_columns`), each of which is a vector of `n` scalars, the return + value is a Tensor of shape `(n, k)` representing `n` input points, where the + items in the `k` dimension are sorted lexicographically by `features` key. + If `features` is not a dict, it is returned unmodified. """ - if isinstance(features, dict): - keys = sorted(features.keys()) - with ops.colocate_with(features[keys[0]]): - features = array_ops.concat([features[k] for k in keys], axis=1) - return features + if not isinstance(features, dict): + return features + + if feature_columns: + return fc.input_layer(features, feature_columns) + + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + return array_ops.concat([features[k] for k in keys], axis=1) class _ModelFn(object): @@ -130,7 +139,8 @@ class _ModelFn(object): def __init__(self, num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance): + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns): self._num_clusters = num_clusters self._initial_clusters = initial_clusters self._distance_metric = distance_metric @@ -139,6 +149,7 @@ class _ModelFn(object): self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries self._relative_tolerance = relative_tolerance + self._feature_columns = feature_columns def model_fn(self, features, mode, config): """Model function for the estimator. @@ -166,7 +177,7 @@ class _ModelFn(object): # input_points is a single Tensor. Therefore, the sharding functionality # in clustering_ops is unused, and some of the values below are lists of a # single item. - input_points = _parse_tensor_or_dict(features) + input_points = _parse_features_if_necessary(features, self._feature_columns) # Let N = the number of input_points. # all_distances: A list of one matrix of shape (N, num_clusters). Each value @@ -233,7 +244,57 @@ class _ModelFn(object): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + Example: + ``` + import numpy as np + import tensorflow as tf + + num_points = 100 + dimensions = 2 + points = np.random.uniform(0, 1000, [num_points, dimensions]) + + def input_fn(): + return tf.train.limit_epochs( + tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1) + + num_clusters = 5 + kmeans = tf.contrib.factorization.KMeansClustering( + num_clusters=num_clusters, use_mini_batch=False) + + # train + num_iterations = 10 + previous_centers = None + for _ in xrange(num_iterations): + kmeans.train(input_fn) + cluster_centers = kmeans.cluster_centers() + if previous_centers is not None: + print 'delta:', cluster_centers - previous_centers + previous_centers = cluster_centers + print 'score:', kmeans.score(input_fn) + print 'cluster centers:', cluster_centers + + # map the input points to their clusters + cluster_indices = list(kmeans.predict_cluster_index(input_fn)) + for i, point in enumerate(points): + cluster_index = cluster_indices[i] + center = cluster_centers[cluster_index] + print 'point:', point, 'is in cluster', cluster_index, 'centered at', center + ``` + + The `SavedModel` saved by the `export_savedmodel` method does not include the + cluster centers. However, the cluster centers may be retrieved by the + latest checkpoint saved during training. Specifically, + ``` + kmeans.cluster_centers() + ``` + is equivalent to + ``` + tf.train.load_variable( + kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME) + ``` + """ # Valid values for the distance_metric constructor argument. SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE @@ -253,6 +314,9 @@ class KMeansClustering(estimator.Estimator): CLUSTER_INDEX = 'cluster_index' ALL_DISTANCES = 'all_distances' + # Variable name used by cluster_centers(). + CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME + def __init__(self, num_clusters, model_dir=None, @@ -263,7 +327,8 @@ class KMeansClustering(estimator.Estimator): mini_batch_steps_per_iteration=1, kmeans_plus_plus_num_retries=2, relative_tolerance=None, - config=None): + config=None, + feature_columns=None): """Creates an Estimator for running KMeans training and inference. This Estimator implements the following variants of the K-means algorithm: @@ -309,11 +374,11 @@ class KMeansClustering(estimator.Estimator): than `num_clusters`, a TensorFlow runtime error occurs. distance_metric: The distance metric used for clustering. One of: * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance - between vectors `u` and `v` is defined as `||u - v||_2` which is - the square root of the sum of the absolute squares of the elements' - difference. + between vectors `u` and `v` is defined as \\(||u - v||_2\\) + which is the square root of the sum of the absolute squares of + the elements' difference. * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors - `u` and `v` is defined as `1 - (u . v) / (||u||_2 ||v||_2)`. + `u` and `v` is defined as \\(1 - (u . v) / (||u||_2 ||v||_2)\\). random_seed: Python integer. Seed for PRNG used to initialize centers. use_mini_batch: A boolean specifying whether to use the mini-batch k-means algorithm. See explanation above. @@ -330,6 +395,10 @@ class KMeansClustering(estimator.Estimator): iterations. Stops learning if the loss changes less than this amount. This may not work correctly if `use_mini_batch=True`. config: See @{tf.estimator.Estimator}. + feature_columns: An optionable iterable containing all the feature columns + used by the model. All items in the set should be feature column + instances that can be passed to `tf.feature_column.input_layer`. If this + is None, all features will be used. Raises: ValueError: An invalid argument was passed to `initial_clusters` or @@ -349,7 +418,8 @@ class KMeansClustering(estimator.Estimator): model_fn=_ModelFn( num_clusters, initial_clusters, distance_metric, random_seed, use_mini_batch, mini_batch_steps_per_iteration, - kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + kmeans_plus_plus_num_retries, relative_tolerance, + feature_columns).model_fn, model_dir=model_dir, config=config) @@ -406,4 +476,4 @@ class KMeansClustering(estimator.Estimator): def cluster_centers(self): """Returns the cluster centers.""" - return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) + return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index f9598bfc08c05ea3bba88b3135da0cf2e6bb0c95..88eb9cf692992fe2e1fc4f060ac98dd721c22307 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -27,6 +27,7 @@ from sklearn.cluster import KMeans as SklearnKMeans # pylint: disable=g-import-not-at-top from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib from tensorflow.python.estimator import run_config +from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -226,6 +227,44 @@ class KMeansTest(KMeansTestBase): self._infer_helper(kmeans, clusters, 10) self._infer_helper(kmeans, clusters, 1) + def _parse_feature_dict_helper(self, features, parsed_feature_dict): + # Perform a sanity check. + self.assertEqual(features.shape, parsed_feature_dict.shape) + self.assertEqual(features.dtype, parsed_feature_dict.dtype) + # Then check that running the tensor yields the original list of points. + with self.test_session() as sess: + parsed_points = sess.run(parsed_feature_dict) + self.assertAllEqual(self.points, parsed_points) + + def test_parse_features(self): + """Tests the various behaviours of kmeans._parse_features_if_necessary.""" + + # No-op if a tensor is passed in. + features = constant_op.constant(self.points) + parsed_features = kmeans_lib._parse_features_if_necessary(features, None) + self.assertAllEqual(features, parsed_features) + + # All values from a feature dict are transformed into a tensor. + feature_dict = { + 'x': [[point[0]] for point in self.points], + 'y': [[point[1]] for point in self.points] + } + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict, None) + self._parse_feature_dict_helper(features, parsed_feature_dict) + + # Only the feature_columns of a feature dict are transformed into a tensor. + feature_dict_with_extras = { + 'foo': 'bar', + 'x': [[point[0]] for point in self.points], + 'baz': {'fizz': 'buzz'}, + 'y': [[point[1]] for point in self.points] + } + feature_columns = [fc.numeric_column(key='x'), fc.numeric_column(key='y')] + parsed_feature_dict = kmeans_lib._parse_features_if_necessary( + feature_dict_with_extras, feature_columns) + self._parse_feature_dict_helper(features, parsed_feature_dict) + class KMeansTestMultiStageInit(KMeansTestBase): @@ -374,7 +413,7 @@ class KMeansCosineDistanceTest(KMeansTestBase): self.assertAllClose(score, self.true_score, atol=1e-2) def test_predict_kmeans_plus_plus(self): - # Most points are concetrated near one center. KMeans++ is likely to find + # Most points are concentrated near one center. KMeans++ is likely to find # the less populated centers. points = np.array( [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], @@ -394,7 +433,6 @@ class KMeansCosineDistanceTest(KMeansTestBase): true_assignments = [0] * 2 + [1] * 2 + [2] * 8 true_score = len(points) - np.tensordot( normalize(points), true_centers[true_assignments]) - kmeans = kmeans_lib.KMeansClustering( 3, initial_clusters=self.initial_clusters, @@ -566,7 +604,7 @@ class KMeansTestQueues(test.TestCase): return _fn # This test makes sure that there are no deadlocks when using a QueueRunner. - # Note that since cluster initialization is dependendent on inputs, if input + # Note that since cluster initialization is dependent on inputs, if input # is generated using a QueueRunner, one has to make sure that these runners # are started before the initialization. def test_queues(self): diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9..ca46c39baa16a7fddb96121e0402fc35d24ce1c2 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -216,7 +216,7 @@ def _wals_factorization_model_function(features, labels, mode, params): name=WALSMatrixFactorization.LOSS, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + # \\(\sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )\\) rwse_var = variable_scope.variable( 0., trainable=False, @@ -235,7 +235,7 @@ def _wals_factorization_model_function(features, labels, mode, params): num_items: An integer, the total number of items of this axis. update_fn: A function that takes one argument (`sp_input`), and that returns a tuple of - * new_factors: A flot Tensor of the factor values after update. + * new_factors: A float Tensor of the factor values after update. * update_op: a TensorFlow op which updates the factors. * loss: A float Tensor, the unregularized loss. * reg_loss: A float Tensor, the regularization loss. @@ -490,11 +490,11 @@ class WALSMatrixFactorization(estimator.Estimator): and the problem simplifies to ALS. Note that, in this case, col_weights must also be set to "None". - List of lists of non-negative scalars, of the form - [[w_0, w_1, ...], [w_k, ... ], [...]], + \\([[w_0, w_1, ...], [w_k, ... ], [...]]\\), where the number of inner lists equal to the number of row factor shards and the elements in each inner list are the weights for the rows of that shard. In this case, - w_ij = unonbserved_weight + row_weights[i] * col_weights[j]. + \\(w_ij = unonbserved_weight + row_weights[i] * col_weights[j]\\). - A non-negative scalar: This value is used for all row weights. Note that it is allowed to have row_weights as a list and col_weights as a scalar, or vice-versa. diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 6fc053759c58d30c24657dd22e7d12be46fc7a7e..aab7d0c9e8874269bfa5f33193b0dc0ba4bbc9cd 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -8,30 +8,47 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "feature_column_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":sequential_feature_column", + ":sequence_feature_column", + "//tensorflow/python:util", ], ) py_library( - name = "sequential_feature_column", - srcs = ["python/feature_column/sequential_feature_column.py"], + name = "sequence_feature_column", + srcs = ["python/feature_column/sequence_feature_column.py"], srcs_version = "PY2AND3", - deps = [], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + ], +) + +py_test( + name = "sequence_feature_column_test", + srcs = ["python/feature_column/sequence_feature_column_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequence_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/contrib/feature_column/__init__.py b/tensorflow/contrib/feature_column/__init__.py index 6da7b126931effae9cc97091a27070d7013450d4..baa8c1567a5aeb39976ab04c54ae2728ba050a7c 100644 --- a/tensorflow/contrib/feature_column/__init__.py +++ b/tensorflow/contrib/feature_column/__init__.py @@ -19,12 +19,18 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.feature_column.python.feature_column.sequential_feature_column import * +from tensorflow.contrib.feature_column.python.feature_column.sequence_feature_column import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + 'sequence_categorical_column_with_hash_bucket', + 'sequence_categorical_column_with_identity', + 'sequence_categorical_column_with_vocabulary_list', + 'sequence_categorical_column_with_vocabulary_file', + 'sequence_input_layer', + 'sequence_numeric_column', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..555beddeaab419bcb23d06f960d370b706d744c8 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -0,0 +1,447 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental methods for tf.feature_column sequence input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import collections + + +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# pylint: disable=protected-access +# TODO(b/73827486): Support SequenceExample. + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + features: A dict mapping keys to tensors. + feature_columns: An iterable of dense sequence columns. Valid columns are + - `embedding_column` that wraps a `sequence_categorical_column_with_*` + - `sequence_numeric_column`. + weight_collections: A list of collection names to which the Variable will be + added. Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES`. + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc._clean_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, fc._SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'You can wrap a sequence_categorical_column with an embedding_column ' + 'or indicator_column. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + None, default_name='sequence_input_layer', values=features.values()): + builder = fc._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + output_tensors.append( + array_ops.reshape( + dense_tensor, + shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) + fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns) + sequence_length = _assert_all_equal_and_return(sequence_lengths) + return array_ops.concat(output_tensors, -1), sequence_length + + +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + """Returns a feature column that represents sequences of integers. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [watches_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + num_buckets: Range of inputs. Namely, inputs are expected to be in the + range `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace out-of-range inputs. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +def sequence_categorical_column_with_hash_bucket( + key, hash_bucket_size, dtype=dtypes.string): + """A sequence of categorical terms where ids are set by hashing. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + tokens = sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=1000) + tokens_embedding = embedding_column(tokens, dimension=10) + columns = [tokens_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + hash_bucket_size: An int > 1. The number of buckets. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `hash_bucket_size` is not greater than 1. + ValueError: `dtype` is neither string nor integer. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_hash_bucket( + key=key, + hash_bucket_size=hash_bucket_size, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_file( + key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0, + default_value=None, dtype=dtypes.string): + """A sequence of categorical terms where ids use a vocabulary file. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + states = sequence_categorical_column_with_vocabulary_file( + key='states', vocabulary_file='/us/states.txt', vocabulary_size=50, + num_oov_buckets=5) + states_embedding = embedding_column(states, dimension=10) + columns = [states_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_file: The vocabulary file name. + vocabulary_size: Number of the elements in the vocabulary. This must be no + greater than length of `vocabulary_file`, if less than length, later + values are ignored. If None, it is set to the length of `vocabulary_file`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of + the input value. A positive `num_oov_buckets` can not be specified with + `default_value`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + dtype: The type of features. Only string and integer types are supported. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: `vocabulary_file` is missing or cannot be opened. + ValueError: `vocabulary_size` is missing or < 1. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: `dtype` is neither string nor integer. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_file( + key=key, + vocabulary_file=vocabulary_file, + vocabulary_size=vocabulary_size, + num_oov_buckets=num_oov_buckets, + default_value=default_value, + dtype=dtype)) + + +def sequence_categorical_column_with_vocabulary_list( + key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0): + """A sequence of categorical terms where ids use an in-memory list. + + Pass this to `embedding_column` or `indicator_column` to convert sequence + categorical data into dense representation for input to sequence NN, such as + RNN. + + Example: + + ```python + colors = sequence_categorical_column_with_vocabulary_list( + key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), + num_oov_buckets=2) + colors_embedding = embedding_column(colors, dimension=3) + columns = [colors_embedding] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input feature. + vocabulary_list: An ordered iterable defining the vocabulary. Each feature + is mapped to the index of its value (if present) in `vocabulary_list`. + Must be castable to `dtype`. + dtype: The type of features. Only string and integer types are supported. + If `None`, it will be inferred from `vocabulary_list`. + default_value: The integer ID value to return for out-of-vocabulary feature + values, defaults to `-1`. This can not be specified with a positive + `num_oov_buckets`. + num_oov_buckets: Non-negative integer, the number of out-of-vocabulary + buckets. All out-of-vocabulary inputs will be assigned IDs in the range + `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a + hash of the input value. A positive `num_oov_buckets` can not be specified + with `default_value`. + + Returns: + A `_SequenceCategoricalColumn`. + + Raises: + ValueError: if `vocabulary_list` is empty, or contains duplicate keys. + ValueError: `num_oov_buckets` is a negative integer. + ValueError: `num_oov_buckets` and `default_value` are both specified. + ValueError: if `dtype` is not integer or string. + """ + return fc._SequenceCategoricalColumn( + fc.categorical_column_with_vocabulary_list( + key=key, + vocabulary_list=vocabulary_list, + dtype=dtype, + default_value=default_value, + num_oov_buckets=num_oov_buckets)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32): + """Returns a feature column that represents sequences of numeric data. + + Example: + + ```python + temperature = sequence_numeric_column('temperature') + columns = [temperature] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Args: + key: A unique string identifying the input features. + shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`, + each example must contain `2 * sequence_length` values. + default_value: A single value compatible with `dtype` that is used for + padding the sparse data into a dense `Tensor`. + dtype: The type of values. + + Returns: + A `_SequenceNumericColumn`. + + Raises: + TypeError: if any dimension in shape is not an int. + ValueError: if any dimension in shape is not a positive integer. + ValueError: if `dtype` is not convertible to `tf.float32`. + """ + shape = fc._check_shape(shape=shape, key=key) + if not (dtype.is_integer or dtype.is_floating): + raise ValueError('dtype must be convertible to float. ' + 'dtype: {}, key: {}'.format(dtype, key)) + + return _SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype) + + +def _assert_all_equal_and_return(tensors, name=None): + """Asserts that all tensors are equal and returns the first one.""" + with ops.name_scope(name, 'assert_all_equal', values=tensors): + if len(tensors) == 1: + return tensors[0] + assert_equal_ops = [] + for t in tensors[1:]: + assert_equal_ops.append(check_ops.assert_equal(tensors[0], t)) + with ops.control_dependencies(assert_equal_ops): + return array_ops.identity(tensors[0]) + + +class _SequenceNumericColumn( + fc._SequenceDenseColumn, + collections.namedtuple( + '_SequenceNumericColumn', + ['key', 'shape', 'default_value', 'dtype'])): + """Represents sequences of numeric data.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_spec(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.key) + + @property + def _variable_shape(self): + return tensor_shape.TensorShape(self.shape) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + sp_tensor = inputs.get(self) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + sequence_length = fc._sequence_length_from_sparse_tensor( + sp_tensor, num_elements=self._variable_shape.num_elements()) + return fc._SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py new file mode 100644 index 0000000000000000000000000000000000000000..88f5d535162939e063eb1e7f43d495137c5adef4 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -0,0 +1,816 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase): + + def test_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = fc.embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + vocabulary_size_a = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + vocabulary_size_b = 2 + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [1, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 1, 0), + dense_shape=(2, 2)) + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [1, 0] + [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size_a) + indicator_column_a = fc.indicator_column(categorical_column_a) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size_b) + indicator_column_b = fc.indicator_column(categorical_column_b) + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[indicator_column_b, indicator_column_a]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_indicator_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence categorical column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must be of ' + r'type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + def test_numeric_column(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_input_layer = [ + [[0.], [1.]], + [[10.], [0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column_multi_dim(self): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + # The output of numeric_column._get_dense_tensor should be flattened. + expected_input_layer = [ + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_not_equal(self): + """Tests that an error is raised when sequence lengths are not equal.""" + # Input a with sequence_length = [2, 1] + sparse_input_a = sparse_tensor.SparseTensorValue( + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + # Input b with sequence_length = [1, 1] + sparse_input_b = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0)), + values=(1., 10.), + dense_shape=(2, 2)) + numeric_column_a = sfc.sequence_numeric_column('aaa') + numeric_column_b = sfc.sequence_numeric_column('bbb') + + _, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=[numeric_column_a, numeric_column_b]) + + with monitored_session.MonitoredSession() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Condition x == y did not hold element-wise:\] ' + r'\[x \(sequence_input_layer/aaa/sequence_length:0\) = \] \[2 1\] ' + r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'): + sess.run(sequence_length) + + +class InputLayerTest(test.TestCase): + """Tests input_layer with sequence feature columns.""" + + def test_embedding_column(self): + """Tests that error is raised for sequence embedding column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = fc.embedding_column( + categorical_column_a, dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_embedding\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc.input_layer( + features={'aaa': sparse_input}, + feature_columns=[embedding_column_a]) + + def test_indicator_column(self): + """Tests that error is raised for sequence indicator column.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column_a = fc.indicator_column(categorical_column_a) + + with self.assertRaisesRegexp( + ValueError, + r'In indicator_column: aaa_indicator\. categorical_column must not be ' + r'of type _SequenceCategoricalColumn\.'): + _ = fc.input_layer( + features={'aaa': sparse_input}, + feature_columns=[indicator_column_a]) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + _assert_sparse_tensor_indices_shape(test_case, expected, actual) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + +def _assert_sparse_tensor_indices_shape(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_inputs3d(self): + """Tests _get_sparse_tensors when the input is already 3D Tensor.""" + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 2, 0), + dense_shape=(2, 2, 1)) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'Column aaa expected ID tensor of rank 2\.\s*' + r'id_tensor shape:\s*\[2 2 1\]'): + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': inputs})) + with monitored_session.MonitoredSession() as sess: + id_weight_pair.id_tensor.eval(session=sess) + + +class SequenceCategoricalColumnWithHashBucketTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_hash_bucket( + 'aaa', hash_bucket_size=10) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + # Ignored to avoid hash dependence in test. + values=np.array((0, 0, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_indices_shape( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase): + + def _write_vocab(self, vocab_strings, file_name): + vocab_file = os.path.join(self.get_temp_dir(), file_name) + with open(vocab_file, 'w') as f: + f.write('\n'.join(vocab_strings)) + return vocab_file + + def setUp(self): + super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp() + + vocab_strings = ['omar', 'stringer', 'marlo'] + self._wire_vocabulary_file_name = self._write_vocab(vocab_strings, + 'wire_vocabulary.txt') + self._wire_vocabulary_size = 3 + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_vocabulary_file( + key='aaa', + vocabulary_file=self._wire_vocabulary_file_name, + vocabulary_size=self._wire_vocabulary_size) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_vocabulary_list( + key='aaa', + vocabulary_list=('omar', 'stringer', 'marlo')) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((2, -1, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + +class SequenceEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + expected_lookups = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = fc.embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceIndicatorColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + expected_lookups = [ + # example 0, ids [2] + [[0., 0., 1.], [0., 0., 0.]], + # example 1, ids [0, 1] + [[1., 0., 0.], [0., 1., 0.]], + # example 2, ids [] + [[0., 0., 0.], [0., 0., 0.]], + # example 3, ids [1] + [[0., 1., 0.], [0., 0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + indicator_tensor, _ = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + indicator_column = fc.indicator_column(categorical_column) + + _, sequence_length = indicator_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceNumericColumnTest(test.TestCase): + + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + + def test_get_sequence_dense_tensor(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_dense_tensor = [ + [[0.], [1.]], + [[10.], [0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_shape(self): + """Tests get_sequence_dense_tensor with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_dense_tensor = [ + [[0., 1., 2.], [3., 4., 5.]], + [[10., 11., 12.], [0., 0., 0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_dense_tensor_multi_dim(self): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + expected_dense_tensor = [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_sequence_length(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + sequence_length = sess.run(sequence_length) + self.assertAllEqual(expected_sequence_length, sequence_length) + self.assertEqual(np.int64, sequence_length.dtype) + + def test_sequence_length_with_shape(self): + """Tests _sequence_length with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index eccce99071dc1477cf4f3bb152f3304b3b0fc35a..f7b3273a4d35eadb9fad49399b7bf18d4bd33503 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -180,15 +180,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 6b455567d766dbe6d380a498bd7f521db27e077b..59bad8982dd163f89f37e1a0a9d5017d0c495de3 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -74,15 +74,3 @@ tf_cc_test( "//tensorflow/core:test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index e61221a6b0d34373279a379f356c99c379488182..35341406a08dc681c861aea30fcff784e3b963ef 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -256,6 +256,9 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height, if (p != std::string::npos) { string rgb24 = line.substr(p + 9, line.find(" ", p + 9)); rgb24 = rgb24.substr(0, rgb24.find(",")); + // Strip anything after " ", in case the format is + // `640x360 [SAR 1:1 DAR 16:9]` + rgb24 = rgb24.substr(0, rgb24.find(" ")); string rgb24_width = rgb24.substr(0, rgb24.find("x")); string rgb24_height = rgb24.substr(rgb24_width.length() + 1); if (strings::safe_strtou32(rgb24_width, &width_value) && @@ -270,8 +273,10 @@ Status ReadInfoFile(const string& filename, uint32* width, uint32* height, // We only look for the first stream mapping to have the number of the // frames. // Once processed we will not further process stream mapping section. - if (line.find("frame= ") == 0) { - string number = line.substr(8, line.find(" ", 8)); + if (line.find("frame=") == 0) { + // The format might be `frame= 166 ` or `frame=12488 ` + string number = line.substr(6); + number = number.substr(number.find_first_not_of(" ")); number = number.substr(0, number.find(" ")); if (strings::safe_strtou32(number, &frames_value)) { in_mapping = false; diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 9e5f54f0973eae899ca65e4098358107053cb7d4..b1c8ad49eaf8d2400e431fcf4820fca6e0314557 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", - "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -63,7 +62,9 @@ tf_custom_op_py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:smart_cond", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -161,23 +162,6 @@ py_test( ], ) -py_test( - name = "accumulate_n_v2_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - cuda_py_test( name = "critical_section_test", size = "medium", @@ -185,31 +169,14 @@ cuda_py_test( additional_deps = [ "//tensorflow/python:client_testlib", ":framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:platform_test", "//tensorflow/python:resource_variable_ops", - ], -) - -py_test( - name = "accumulate_n_v2_eager_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_eager_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", - "//third_party/py/numpy", + "//tensorflow/python:tensor_array_ops", ], ) @@ -354,15 +321,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index a49d42cd525434d4ffd4a6bb0d8854dc707b9280..11397e86bd811e376ff4b60f86109d9ee5dd1a62 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -71,6 +71,10 @@ See the @{$python/contrib.framework} guide. @@model_variable @@variable @@VariableDeviceChooser +@@convolutional_delta_orthogonal +@@convolutional_orthogonal_1d +@@convolutional_orthogonal_2d +@@convolutional_orthogonal_3d @@zero_initializer @@load_checkpoint @@ -82,11 +86,16 @@ See the @{$python/contrib.framework} guide. @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer +@@argsort @@py_func @@sort @@get_placeholders +@@smart_cond +@@smart_constant_value +@@smart_case + @@CriticalSection @@BoundedTensorSpec @@ -104,12 +113,18 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope - +from tensorflow.python.framework.smart_cond import smart_case +from tensorflow.python.framework.smart_cond import smart_cond +from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec - +from tensorflow.python.ops.array_ops import broadcast_to +from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal +from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d +from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d +from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest'] +_allowed_symbols = ['nest', 'broadcast_to'] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index 8e54e09e04ee3c0ddbd4fa84cc0912cb70c93e62..cfdc7df7d8fd4c1406bf447a79038ac33b11e047 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -49,7 +49,6 @@ class ExperimentalTest(test.TestCase): "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " "be removed at any time, and without warning." "\n" - "\n" "\nArgs:" "\n arg0: Arg 0." "\n arg1: Arg 1." diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 49eec3a3f1a0f357ea3adfade51e71cb0f89942d..2703224b1bf62831b6088558d4f93950fe938c10 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -85,14 +85,19 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, if n not in reachable_by_input and n not in output_nodes_set: # n is between input and output, i.e., part of the fused op next_to_visit = [n] + visited = set() while next_to_visit: cur_node = next_to_visit[0] + visited.add(cur_node) del next_to_visit[0] if cur_node in reachable_by_input and cur_node not in input_nodes_set: raise TypeError("Node %s uses input %s not in input_nodes." % (n, cur_node)) if cur_node not in input_nodes_set: - next_to_visit += name_to_input_name[cur_node] + next_to_visit += [ + input_node for input_node in name_to_input_name[cur_node] + if input_node not in visited + ] elif n not in reachable_by_input: nodes_post_output.append(n) diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index b8a6d109e19211d271c2b15bac66ddacd38fe395..812c5fbd8cb759aef6eb1aad532c03794b2ceaf4 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -42,7 +42,8 @@ class GraphUtilTest(test.TestCase): graph_def = graph_pb2.GraphDef() node_a = GetNewNode('A', 'Placeholder', []) node_b = GetNewNode('B', 'Op1', ['A']) - node_c = GetNewNode('C', 'Op1', ['B']) + # A loop in the part that will be fused. + node_c = GetNewNode('C', 'Op1', ['B', 'C']) node_d = GetNewNode('D', 'Op1', ['C']) node_e = GetNewNode('E', 'Op1', ['D']) graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 8cdb340f2ddd9b3a7f55c1937ef045f4627e99be..8fc4f60492b0bfb22ea78cb7b5906e452bb6da58 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -48,7 +48,7 @@ class LocalVariabletest(test.TestCase): variables = variables_lib.local_variables() self.assertEquals(2, len(variables)) self.assertRaises(errors_impl.OpError, sess.run, variables) - variables_lib.initialize_variables(variables).run() + variables_lib.variables_initializer(variables).run() self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) @@ -209,6 +209,7 @@ class WithShapeTest(test.TestCase): self.assertRaisesRegexp(errors_impl.OpError, "Wrong shape", tensor_2x2.eval, {tensor_no_shape: [42.0]}) + @test_util.enable_c_shapes def test_with_shape_partial(self): with self.test_session(): tensor_partial_shape = array_ops.placeholder(dtypes.float32) diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py deleted file mode 100644 index 476528b0dd3df05239d5dc402b466e06dd789985..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops - - - -def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): - """Returns the element-wise sum of a list of tensors. - - Optionally, pass `shape` and `tensor_dtype` for shape and type checking, - otherwise, these are inferred. - - `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not - wait for all of its inputs to be ready before beginning to sum. This can - save memory if inputs are ready at different times, since minimum temporary - storage is proportional to the output size rather than the inputs size. - - Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. - - For example: - - ```python - a = tf.constant([[1, 2], [3, 4]]) - b = tf.constant([[5, 0], [0, 6]]) - tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] - - # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) - # [[7, 4], - # [6, 14]] - ``` - - Args: - inputs: A list of `Tensor` objects, each with same shape and type. - shape: Shape of elements of `inputs`. - tensor_dtype: The type of `inputs`. - name: A name for the operation (optional). - - Returns: - A `Tensor` of same shape and type as the elements of `inputs`. - - Raises: - ValueError: If `inputs` don't all have same shape and dtype or the shape - cannot be inferred. - """ - _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" - "with the same dtype and shape") - if not inputs or not isinstance(inputs, (list, tuple)): - raise _INPUTS_ERR_MSG - inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): - raise _INPUTS_ERR_MSG - if not all(x.dtype == inputs[0].dtype for x in inputs): - raise _INPUTS_ERR_MSG - if shape is not None: - shape = tensor_shape.as_shape(shape) - else: - shape = tensor_shape.unknown_shape() - for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): - shape = shape.merge_with(input_tensor.get_shape()) - - # tensor_dtype is for safety only; operator's output type computed in C++ - if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}" - .format(tensor_dtype, inputs[0].dtype)) - - if len(inputs) == 1 and name is None: - return inputs[0] - elif len(inputs) == 1 and name is not None: - return array_ops.identity(inputs[0], name=name) - elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back - # onto AddN for now. - # TODO(frreiss) remove this once the lifetime of eager variables gets - # addressed - return math_ops.add_n(inputs, name=name) - else: - return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) - -# The following code should eventually be merged into -# tensorflow/python/ops/math_grad.py -@ops.RegisterGradient("AccumulateNV2") -def _AddNGrad(op, grad): - """Same as gradient for AddN. Copies the gradient to all inputs.""" - # Not broadcasting. - return [grad] * len(op.inputs) diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 409657fe1da0e5540cd2ad6070d86737c039e91f..5b150339953f961c756c0909dd1795341159b9cd 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -68,7 +68,7 @@ from tensorflow.python.util import tf_decorator __all__ = [ 'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope', - 'arg_scoped_arguments' + 'arg_scoped_arguments', 'arg_scope_func_key' ] _ARGSTACK = [{}] @@ -89,7 +89,7 @@ def current_arg_scope(): return stack[-1] -def _key_op(op): +def arg_scope_func_key(op): return getattr(op, '_key_op', str(op)) @@ -103,9 +103,9 @@ def _kwarg_names(func): def _add_op(op): - key_op = _key_op(op) - if key_op not in _DECORATED_OPS: - _DECORATED_OPS[key_op] = _kwarg_names(op) + key = arg_scope_func_key(op) + if key not in _DECORATED_OPS: + _DECORATED_OPS[key] = _kwarg_names(op) @tf_contextlib.contextmanager @@ -142,21 +142,21 @@ def arg_scope(list_ops_or_scope, **kwargs): else: # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs. if not isinstance(list_ops_or_scope, (list, tuple)): - raise TypeError('list_ops_or_scope must either be a list/tuple or reused' + raise TypeError('list_ops_or_scope must either be a list/tuple or reused ' 'scope (i.e. dict)') try: current_scope = current_arg_scope().copy() for op in list_ops_or_scope: - key_op = _key_op(op) + key = arg_scope_func_key(op) if not has_arg_scope(op): raise ValueError('%s is not decorated with @add_arg_scope', _name_op(op)) - if key_op in current_scope: - current_kwargs = current_scope[key_op].copy() + if key in current_scope: + current_kwargs = current_scope[key].copy() current_kwargs.update(kwargs) - current_scope[key_op] = current_kwargs + current_scope[key] = current_kwargs else: - current_scope[key_op] = kwargs.copy() + current_scope[key] = kwargs.copy() _get_arg_stack().append(current_scope) yield current_scope finally: @@ -176,14 +176,14 @@ def add_arg_scope(func): def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs - key_func = _key_op(func) + key_func = arg_scope_func_key(func) if key_func in current_scope: current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) _add_op(func) - setattr(func_with_args, '_key_op', _key_op(func)) + setattr(func_with_args, '_key_op', arg_scope_func_key(func)) return tf_decorator.make_decorator(func, func_with_args) @@ -196,7 +196,7 @@ def has_arg_scope(func): Returns: a boolean. """ - return _key_op(func) in _DECORATED_OPS + return arg_scope_func_key(func) in _DECORATED_OPS def arg_scoped_arguments(func): @@ -209,4 +209,4 @@ def arg_scoped_arguments(func): a list of kwargs names. """ assert has_arg_scope(func) - return _DECORATED_OPS[_key_op(func)] + return _DECORATED_OPS[arg_scope_func_key(func)] diff --git a/tensorflow/contrib/framework/python/ops/arg_scope_test.py b/tensorflow/contrib/framework/python/ops/arg_scope_test.py index 7ba9d4ffa90f6860629b15a2ea91e0c573bf6368..4c3879d4fc08b53ea8be5f1256a830a64fb39af6 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope_test.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope_test.py @@ -170,6 +170,30 @@ class ArgScopeTest(test.TestCase): self.assertTupleEqual(args, func1_args) self.assertDictEqual(kwargs, func1_kwargs) + def testNestedArgScopeObjectCreatedOutsideScopeOverridesArgScope(self): + + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + scope_object = get_scope_object() + with arg_scope([func1], b=2, d=10): + with arg_scope(scope_object): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1]}) + + def testArgScopeObjectCreatedWithinScopeInheritsArgScope(self): + def get_scope_object(): + with arg_scope([func1], a=1, b=None, c=[1]) as sc: + return sc + + with arg_scope([func1], b=2, d=10): + with arg_scope(get_scope_object()): + args, kwargs = func1(0) + self.assertTupleEqual(args, (0,)) + self.assertDictEqual(kwargs, {'a': 1, 'b': None, 'c': [1], 'd': 10}) + def testSharedArgScope(self): func1_args = (0,) func1_kwargs = {'a': 1, 'b': None, 'c': [1]} diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 182fec924febb74a23b82b1664d137f033f3b1b4..bd764ed57a6da0a4d356235108e998a80ac34362 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -24,10 +24,12 @@ import collections # from tensorflow.core.protobuf import critical_section_pb2 from tensorflow.python.eager import context -from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -38,11 +40,32 @@ CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" class _ExecutionSignature( collections.namedtuple("_ExecutionSignature", - ("op", "exclusive_resource_access"))): + ("op", "handle", + "resources", "exclusive_resource_access"))): """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" pass +def _identity(x): + """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.""" + if isinstance(x, tensor_array_ops.TensorArray): + return x.identity() + elif isinstance(x, ops.Operation): + return control_flow_ops.group(x) + elif context.executing_eagerly() and x is None: + return None + else: + return array_ops.identity(x) + + +def _get_colocation(op): + """Get colocation symbol from op, if any.""" + try: + return op.get_attr("_class") + except ValueError: + return None + + class CriticalSection(object): """Critical section. @@ -112,16 +135,18 @@ class CriticalSection(object): ``` """ - def __init__(self, name=None, critical_section_def=None, import_scope=None): + def __init__(self, name=None, shared_name=None, + critical_section_def=None, import_scope=None): """Creates a critical section.""" if critical_section_def and name is not None: - raise ValueError("critical_section_def and name are mutually exclusive.") + raise ValueError("critical_section_def and shared_name are " + "mutually exclusive.") if critical_section_def: self._init_from_proto(critical_section_def, import_scope=import_scope) else: - self._init_from_args(name) + self._init_from_args(name, shared_name) - def _init_from_proto(self, critical_section_def, import_scope): + def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name raise NotImplementedError("Not yet implemented") # TODO(ebrevdo): Re-enable once CriticalSection is in core. # assert isinstance( @@ -133,19 +158,21 @@ class CriticalSection(object): # critical_section_def.critical_section_name, # import_scope=import_scope)) - def _init_from_args(self, name): + def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name """Initialize the CriticalSection from constructor arguments.""" with ops.name_scope(name, "CriticalSection", []) as name: - with ops.control_dependencies(None): + with ops.init_scope(): # pylint: disable=protected-access - handle_name = ops._name_from_scope_name(name) container = ops.get_default_graph()._container # pylint: enable=protected-access + if shared_name is None: + shared_name = name if container is None: container = "" - self._handle = gen_resource_variable_ops.critical_section_op( - shared_name=handle_name, name=name) - if context.in_graph_mode(): + self._handle = gen_resource_variable_ops.mutex_v2( + shared_name=shared_name, container=container, name=name) + + if not context.executing_eagerly(): ops.add_to_collections(CRITICAL_SECTIONS, self) @property @@ -171,8 +198,8 @@ class CriticalSection(object): The tensors returned from `fn(*args, **kwargs)`. Raises: - ValueError: If `fn` attempts to use this `CriticalSection` in any nested - way. + ValueError: If `fn` attempts to lock this `CriticalSection` in any nested + or lazy way that may cause a deadlock. ValueError: If `exclusive_resource_access` is not provided (is `True`) and another `CriticalSection` has an execution requesting the same resources as in `*args`, `**kwargs`, and any additionaly captured @@ -183,68 +210,163 @@ class CriticalSection(object): name = kwargs.pop("name", None) exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) - args = nest.map_structure(ops.convert_to_tensor, args) with ops.name_scope(name, "critical_section_execute", []): - fn_op = function.make_defun_op(fn, *args, **kwargs) - flat_dtypes = nest.flatten(fn_op.output_dtypes) - flat_shapes = nest.flatten(fn_op.output_shapes) - all_inputs = nest.flatten(args) + fn_op.captured_inputs - if self._handle in all_inputs: - raise ValueError("The function fn attempts to access the " - "CriticalSection in which it would be running. This " - "is illegal and would cause deadlocks. " - "CriticalSection: %s." % self._handle) - - if context.in_graph_mode(): - # Collections and op introspection does not work in eager - # mode. This is generally ok; since eager mode (as of - # writing) executes sequentially anyway. - all_input_resources = [ - x for x in all_inputs if x.dtype == dtypes.resource] - for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): - if sg.op.inputs[0].name == self._handle.name: - # Other executions in the same critical section are allowed. - continue - if not (exclusive_resource_access or sg.exclusive_resource_access): - # Neither execution requested exclusive access. - continue - sg_input_names = [y.name for y in sg.op.inputs[1:]] - for res in all_input_resources: - if res.name in sg_input_names: - raise ValueError( - "This execution would access resource %s; but either this " - "execution (CriticalSection: %s) or Execution '%s' " - "(CriticalSection: %s) requested exclusive resource access " - "of this resource for their critical section. Did you mean " - "to call execute with keyword argument " - "exclusive_resource_access=False?" - % (res.name, - self.name, - sg.op.name, - sg.op.inputs[0].op.name)) - - flat_outputs = gen_resource_variable_ops.execute_in_critical_section( - critical_section=self._handle, - arguments=all_inputs, - f=fn_op, - output_types=flat_dtypes, - output_shapes=flat_shapes) - - if context.in_graph_mode(): - if isinstance(flat_outputs, ops.Operation): - flat_outputs = [flat_outputs] - op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor) - else flat_outputs[0]) + + # Ensure that mutex locking only happens *after* all args and + # kwargs have been executed. This avoids certain types of deadlocks. + lock = gen_resource_variable_ops.mutex_lock(self._handle) + + if not context.executing_eagerly(): + # NOTE(ebrevdo): This is to ensure we don't pick up spurious + # Operations created by other threads. + with ops.get_default_graph()._lock: # pylint: disable=protected-access + existing_ops = ops.get_default_graph().get_operations() + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + # TODO(ebrevdo): If creating critical sections in a python loop, this + # makes graph creation time quadratic. Revisit if this + # becomes a problem. + created_ops = (set(ops.get_default_graph().get_operations()) + .difference(existing_ops)) + else: + with ops.control_dependencies([lock]): + r = fn(*args, **kwargs) + + if not context.executing_eagerly(): + self._add_control_dependencies_to_lock(created_ops, lock.op) + + # captured_resources is a list of resources that are directly + # accessed only by ops created during fn(), not by any + # ancestors of those ops in the graph. + captured_resources = set([ + input_ for op in created_ops + for input_ in op.inputs + if input_.dtype == dtypes.resource + ]) + + # NOTE(ebrevdo): The only time self._is_self_handle() is True + # in this call is if one of the recently created ops, within + # the execute(), themselves attempt to access the + # CriticalSection. This will cause a deadlock. + if any(self._is_self_handle(x) for x in captured_resources): + raise ValueError("The function fn attempts to directly access the " + "CriticalSection in which it would be running. " + "This is illegal and would cause deadlocks.") + + self._check_multiple_access_to_resources( + captured_resources, exclusive_resource_access) + + r_flat = [_identity(x) for x in nest.flatten(r)] + + with ops.control_dependencies(r_flat): + # The identity must run on the same machine as self._handle + with ops.colocate_with(self._handle): + # Do not use array_ops.identity as there are special + # optimizations within TensorFlow which seem to elide it + # even when optimizations are disabled(!). + ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock( + lock) + + # Make sure that if any element of r is accessed, all of + # them are executed together. + r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r))) + + with ops.control_dependencies([ensure_lock_exists]): + outputs = nest.map_structure(_identity, r) + + if not context.executing_eagerly(): signature = _ExecutionSignature( - op=op, + op=lock.op, + handle=self._handle, + resources=list(captured_resources), exclusive_resource_access=exclusive_resource_access) ops.add_to_collections( CRITICAL_SECTION_EXECUTIONS, signature) - return (flat_outputs[0] - if (len(flat_outputs) == 1 - and isinstance(flat_outputs[0], ops.Operation)) - else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs)) + return outputs + + def _add_control_dependencies_to_lock(self, created_ops, lock_op): + """To avoid deadlocks, all args must be executed before lock_op.""" + # Get all arguments (explicit and captured) of all ops created by fn(). + all_args = set([input_.op for op in created_ops for input_ in op.inputs]) + all_args.update( + input_op for op in created_ops for input_op in op.control_inputs) + # Unfortunately, we can't use sets throughout because TF seems to + # create new Operation objects for the same op sometimes; and we + # can't rely on id(op). + + # pylint: disable=protected-access + all_args_dict = dict((op._id, op) for op in all_args) + + # Remove ops created within fn, or that lock_op already has a + # control dependency on. Also remove a possible self-loop. + for op in created_ops: + all_args_dict.pop(op._id, None) + for op in lock_op.control_inputs: + all_args_dict.pop(op._id, None) + for input_ in lock_op.inputs: + all_args_dict.pop(input_.op._id, None) + all_args_dict.pop(lock_op._id, None) + + all_args = all_args_dict.values() + + if not all_args: + # No control dependencies to add; return early. + return + + # This group is important: it ensures that any ops in all_args + # outside the control context of the lock_op (and this fn, which + # runs in the same context) are added to this context before + # being added to the control dependencies of lock_op. + all_args = control_flow_ops.group(*all_args) + + lock_op._add_control_input(all_args) + # pylint: enable=protected-access + + def _is_self_handle(self, x): + """Check if the tensor `x` is the same Mutex as `self._handle`.""" + return (x.op.type == "MutexV2" + # blank shared_name means the op will create a unique one. + and x.op.get_attr("shared_name") + and (x.op.get_attr("shared_name") == + self._handle.op.get_attr("shared_name")) + and (x.op.device == self._handle.op.device + or _get_colocation(x.op) == _get_colocation(self._handle.op))) + + def _check_multiple_access_to_resources( + self, captured_resources, exclusive_resource_access): + """Raise if captured_resources are accessed by another CriticalSection. + + Args: + captured_resources: Set of tensors of type resource. + exclusive_resource_access: Whether this execution requires exclusive + resource access. + + Raises: + ValueError: If any tensors in `captured_resources` are also accessed + by another `CriticalSection`, and at least one of them requires + exclusive resource access. + """ + # Collections and op introspection does not work in eager + # mode. This is generally ok; since eager mode (as of + # writing) executes sequentially anyway. + for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): + if self._is_self_handle(sg.handle): + # Other executions in the same critical section are allowed. + continue + if not (exclusive_resource_access or sg.exclusive_resource_access): + # Neither execution requested exclusive access. + continue + resource_intersection = captured_resources.intersection(sg.resources) + if resource_intersection: + raise ValueError( + "This execution would access resources: %s. Either this " + "lock (CriticalSection: %s) or lock '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource. Did you mean to call execute with keyword " + "argument exclusive_resource_access=False?" % + (list(resource_intersection), self._handle.name, + sg.op.name, sg.handle.name)) # TODO(ebrevdo): Re-enable once CriticalSection is in core. @@ -276,6 +398,7 @@ class CriticalSection(object): # def _execution_to_proto_fn(execution_signature, export_scope=None): # """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`. +# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. # Args: # execution_signature: Instance of `_ExecutionSignature`. @@ -298,6 +421,7 @@ class CriticalSection(object): # def _execution_from_proto_fn(op_def, import_scope=None): # """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`.""" +# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. # assert isinstance( # op_def, critical_section_pb2.CriticalSectionExecutionDef) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index a416724d3ba1719471d70667e140f9cd2daf86c7..ba660295cb3c97d26da7bf892c78bceee53cf2d4 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -19,14 +19,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import critical_section_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging # TODO(ebrevdo): Re-enable once CriticalSection is in core. # from tensorflow.python.training import saver as saver_lib @@ -35,26 +34,82 @@ class CriticalSectionTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testCreateCriticalSection(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn(a, b): - c = v.read_value() + c = v.value() with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): return array_ops.identity(c) - num_concurrent = 1000 + num_concurrent = 100 r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)] self.evaluate(v.initializer) r_value = self.evaluate(r) self.assertAllClose([2.0 * i for i in range(num_concurrent)], sorted(r_value)) + @test_util.run_in_graph_and_eager_modes() + def testCriticalSectionWithControlFlow(self): + for outer_cond in [False, True]: + for inner_cond in [False, True]: + cs = critical_section_ops.CriticalSection(shared_name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + num_concurrent = 100 + + # pylint: disable=cell-var-from-loop + def fn(a, b): + c = v.read_value() + def true_fn(): + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + return control_flow_ops.cond( + array_ops.identity(inner_cond), true_fn, lambda: c) + + def execute(): + return cs.execute(fn, 1.0, 2.0) + + r = [ + control_flow_ops.cond(array_ops.identity(outer_cond), + execute, + v.read_value) + for _ in range(num_concurrent) + ] + # pylint: enable=cell-var-from-loop + + self.evaluate(v.initializer) + r_value = self.evaluate(r) + if inner_cond and outer_cond: + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + else: + self.assertAllClose([0] * num_concurrent, r_value) + + def testCriticalSectionInParallelDoesntDeadlockOnError(self): + # No eager mode execution of this test because eager does not + # run fn() in parallel, which is where the deadlock could + # potentially occur (in graph mode). + cs = critical_section_ops.CriticalSection(shared_name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn(i): + error = control_flow_ops.Assert((i % 2) == 1, ["Error"]) + with ops.control_dependencies([error]): + return v.read_value() + num_concurrent = 2 + r = [cs.execute(fn, i) for i in range(num_concurrent)] + self.evaluate(v.initializer) + for _ in range(100): + with self.assertRaisesOpError("Error"): + self.evaluate(r) + @test_util.run_in_graph_and_eager_modes() def testCreateCriticalSectionFnReturnsOp(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn_return_op(a, b): @@ -62,7 +117,7 @@ class CriticalSectionTest(test.TestCase): with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): - return () + return control_flow_ops.no_op() num_concurrent = 100 r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)] @@ -71,52 +126,166 @@ class CriticalSectionTest(test.TestCase): final_v = self.evaluate(v) self.assertAllClose(2.0 * num_concurrent, final_v) - def testCreateCriticalSectionRaw(self): - cs = critical_section_ops.CriticalSection(name="cs") - v = resource_variable_ops.ResourceVariable(0.0, name="v") - - @function.Defun(dtypes.float32, dtypes.float32) - def fn(a, b): - c = v.read_value() - with ops.control_dependencies([c]): - nv = v.assign_add(a * b) - with ops.control_dependencies([nv]): - return array_ops.identity(c) - - def execute(fn, *args): - output_args = fn.definition.signature.output_arg - return resource_variable_ops.execute_in_critical_section( - critical_section=cs._handle, - arguments=list(args) + fn.captured_inputs, - f=fn, - output_types=[out.type for out in output_args], - output_shapes=[tensor_shape.TensorShape(None) for _ in output_args]) - - num_concurrent = 1000 - r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)] - self.evaluate(v.initializer) - r_value = self.evaluate(r) - self.assertAllClose([2.0 * i for i in range(num_concurrent)], - sorted(r_value)) - def testCollection(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") self.assertIn( cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)) - execute_op = cs.execute(lambda x: x + 1, 1.0).op + execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute") + execute_op = [ + x for x in execute.graph.get_operations() + if "my_execute" in x.name and "MutexLock" in x.type + ][0] self.assertIn( execute_op, [signature.op for signature in ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) - @test_util.run_in_graph_and_eager_modes() def testRecursiveCriticalSectionAccessIsIllegal(self): - cs = critical_section_ops.CriticalSection(name="cs") + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection() + def fn(x): + return cs.execute(lambda y: y + 1, x) + with self.assertRaisesRegexp( + ValueError, + r"attempts to directly access the CriticalSection in which it " + r"would be running"): + cs.execute(fn, 1.0) + + def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self): + # This one is subtle; and we're being overly cautious here. The + # deadlock we are ensuring we catch is: + # + # to_capture = CS[lambda x: x + 1](1.0) + # deadlocked = CS[lambda x: x + to_capture](1.0) + # + # This would have caused a deadlock because executing `deadlocked` will + # lock the mutex on CS; but then due to dependencies, will attempt + # to compute `to_capture`. This computation requires locking CS, + # but that is not possible now because CS is already locked by + # `deadlocked`. + # + # We check that CriticalSection.execute properly inserts new + # control dependencies to its lock to ensure all captured + # operations are finished before anything runs within the critical section. + cs = critical_section_ops.CriticalSection(shared_name="cs") + fn = array_ops.identity + to_capture = cs.execute(fn, 1.0) + fn_captures = lambda x: x + to_capture + to_capture_too = array_ops.identity(to_capture) + + ex_0 = cs.execute(fn_captures, 1.0) + + with ops.control_dependencies([to_capture]): + # This is OK because to_capture will execute before this next call + ex_1 = cs.execute(fn_captures, 1.0) + + dependency = array_ops.identity(to_capture) + + fn_captures_dependency = lambda x: x + dependency + + ex_2 = cs.execute(fn_captures_dependency, 1.0) + + with ops.control_dependencies([to_capture_too]): + ex_3 = cs.execute(fn_captures_dependency, 1.0) + + # Ensure there's no actual deadlock on to_execute. + self.assertEquals(2.0, self.evaluate(ex_0)) + self.assertEquals(2.0, self.evaluate(ex_1)) + self.assertEquals(2.0, self.evaluate(ex_2)) + self.assertEquals(2.0, self.evaluate(ex_3)) + + def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self): + cs = critical_section_ops.CriticalSection(shared_name="cs") + + def body_implicit_capture(i, j): + # This would have caused a deadlock if not for logic in execute + # that inserts additional control dependencies onto the lock op: + # * Loop body argument j is captured by fn() + # * i is running in parallel to move forward the execution + # * j is not being checked by the predicate function + # * output of cs.execute() is returned as next j. + fn = lambda: j + 1 + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture'\n" + "==============\n") + + def body_implicit_capture_protected(i, j): + # This version is ok because we manually add a control + # dependency on j, which is an argument to the while_loop body + # and captured by fn. + fn = lambda: j + 1 + with ops.control_dependencies([j]): + return (i + 1, cs.execute(fn)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_implicit_capture_protected, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_implicit_capture_protected'\n" + "==============\n") + + def body_args_capture(i, j): + # This version is ok because j is an argument to fn and we can + # ensure there's a control dependency on j. + fn = lambda x: x + 1 + return (i + 1, cs.execute(fn, j)) + + (i_n, j_n) = control_flow_ops.while_loop( + lambda i, _: i < 1000, + body_args_capture, + [0, 0], + parallel_iterations=25) + logging.warn( + "\n==============\nRunning " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + self.assertEquals((1000, 1000), self.evaluate((i_n, j_n))) + logging.warn( + "\n==============\nSuccessfully finished running " + "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock " + "body_args_capture'\n" + "==============\n") + + def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self): + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection(shared_name="cs") + cs_same = critical_section_ops.CriticalSection(shared_name="cs") def fn(x): - return cs.execute(lambda x: x+1, x) + return cs_same.execute(lambda x: x+1, x) with self.assertRaisesRegexp( ValueError, - r"attempts to access the CriticalSection in which it would be running"): + r"attempts to directly access the CriticalSection in which it " + r"would be running"): cs.execute(fn, 1.0) def testMultipleCSExecutionsRequestSameResource(self): @@ -147,6 +316,20 @@ class CriticalSectionTest(test.TestCase): ValueError, "requested exclusive resource access"): cs1.execute(lambda: v2 + 1) + def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0, name="v") + # Make sure that the control dependencies on v do not cause issues + # in the lock_op's automatic control dependency adder. + # + # Note, here v must be a resource variable (or something similar), + # otherwise it gets hoisted into the while_loop by the time we add + # control dependencies to the lock_op. + out = control_flow_ops.while_loop( + lambda i: i < 10, lambda i: cs.execute(lambda j: v + j + 1, i), [0]) + self.evaluate(v.initializer) + self.assertEqual(10, self.evaluate(out)) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): @@ -167,7 +350,7 @@ class CriticalSectionTest(test.TestCase): # self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name) # def testToProto(self): - # cs = critical_section_ops.CriticalSection(name="cs") + # cs = critical_section_ops.CriticalSection(shared_name="cs") # proto = cs.to_proto() # self.assertEqual(proto.critical_section_name, cs._handle.name) # cs_copy = critical_section_ops.CriticalSection.from_proto(proto) diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py index 8f62f0ea7b9b561f235b9496ffda97a9f378d530..1921a77c1e96ee3531d1ed0f98e41c27c9d427ac 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -14,6 +14,7 @@ # ============================================================================== """Support for sorting tensors. +@@argsort @@sort """ @@ -21,6 +22,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -47,64 +51,141 @@ def sort(values, axis=-1, direction='ASCENDING', name=None): ValueError: If axis is not a constant scalar, or the direction is invalid. """ with framework_ops.name_scope(name, 'sort'): - if direction not in _SORT_IMPL: - raise ValueError('%s should be one of %s' % - (direction, ', '.join(sorted(_SORT_IMPL.keys())))) - # Axis must be an integer, not a Tensor. - axis = framework_ops.convert_to_tensor(axis, name='axis') - axis_static = tensor_util.constant_value(axis) - if axis.shape.ndims != 0 or axis_static is None: - raise ValueError('axis must be a constant scalar') - axis_static = int(axis_static) # Avoids NumPy casting error + return _sort_or_argsort(values, axis, direction, return_argsort=False) + + +def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None): + """Returns the indices of a tensor that give its sorted order along an axis. + + For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to + `tf.sort(values)`. For higher dimensions, the output has the same shape as + `values`, but along the given axis, values represent the index of the sorted + element in that slice of the tensor at the given position. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + stable: If True, equal elements in the original tensor will not be + re-ordered in the returned order. Unstable sort is not yet implemented, + but will eventually be the default for performance reasons. If you + require a stable order, pass `stable=True` for forwards compatibility. + name: Optional name for the operation. + + Returns: + An int32 `Tensor` with the same shape as `values`. The indices that would + sort each slice of the given `values` along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + del stable # Unused. + with framework_ops.name_scope(name, 'argsort'): + return _sort_or_argsort(values, axis, direction, return_argsort=True) + + +def _sort_or_argsort(values, axis, direction, return_argsort): + """Internal sort/argsort implementation. + + Args: + values: The input values. + axis: The axis along which to sort. + direction: 'ASCENDING' or 'DESCENDING'. + return_argsort: Whether to return the argsort result. + + Returns: + Either the sorted values, or the indices of the sorted values in the + original tensor. See the `sort` and `argsort` docstrings. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error - values = framework_ops.convert_to_tensor(values, name='values') + values = framework_ops.convert_to_tensor(values, name='values') - return _SORT_IMPL[direction](values, axis_static) + return _SORT_IMPL[direction](values, axis_static, return_argsort) -def _descending_sort(values, axis): +def _descending_sort(values, axis, return_argsort=False): """Sorts values in reverse using `top_k`. Args: values: Tensor of numeric values. axis: Index of the axis which values should be sorted along. + return_argsort: If False, return the sorted values. If True, return the + indices that would sort the values. Returns: The sorted values. """ k = array_ops.shape(values)[axis] rank = array_ops.rank(values) + static_rank = values.shape.ndims # Fast path: sorting the last axis. if axis == -1 or axis + 1 == values.get_shape().ndims: - return nn_ops.top_k(values, k)[0] - - # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. - if axis < 0: - # Make axis a Tensor with the real axis index if needed. - axis += rank - transposition = array_ops.concat( - [ - # Axes up to axis are unchanged. - math_ops.range(axis), - # Swap axis and rank - 1. - [rank - 1], - # Axes in [axis + 1, rank - 1) are unchanged. - math_ops.range(axis + 1, rank - 1), - # Swap axis and rank - 1. - [axis] - ], - axis=0) - top_k_input = array_ops.transpose(values, transposition) - values, unused_indices = nn_ops.top_k(top_k_input, k) - # transposition contains a single cycle of length 2 (swapping 2 elements), - # so it is an involution (it is its own inverse). - return array_ops.transpose(values, transposition) - - -def _ascending_sort(values, axis): + top_k_input = values + transposition = None + else: + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Calculate the actual axis index if counting from the end. Use the static + # rank if available, or else make the axis back into a tensor. + axis += static_rank or rank + if static_rank is not None: + # Prefer to calculate the transposition array in NumPy and make it a + # constant. + transposition = constant_op.constant( + np.r_[ + # Axes up to axis are unchanged. + np.arange(axis), + # Swap axis and rank - 1. + [static_rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + np.arange(axis + 1, static_rank - 1), + # Swap axis and rank - 1. + [axis]], + name='transposition') + else: + # Generate the transposition array from the tensors. + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + + values, indices = nn_ops.top_k(top_k_input, k) + return_value = indices if return_argsort else values + if transposition is not None: + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return_value = array_ops.transpose(return_value, transposition) + return return_value + + +def _ascending_sort(values, axis, return_argsort=False): # Negate the values to get the ascending order from descending sort. - values_or_indices = _descending_sort(-values, axis) - return -values_or_indices + values_or_indices = _descending_sort(-values, axis, return_argsort) + # If not argsort, negate the values again. + return values_or_indices if return_argsort else -values_or_indices _SORT_IMPL = { diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py index d08ae502f10d98ee14d8bea2f76b18bedb935cea..a8fb94b245dccc8c7cf0e94cef9b436f881fe408 100644 --- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -24,6 +24,8 @@ from tensorflow.contrib.framework.python.ops import sort_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -90,6 +92,38 @@ class SortTest(test.TestCase): axis=0, direction='DESCENDING').eval()) + def testSort_staticallyKnownRank_constantTransposition(self): + # The transposition array should be a constant if the rank of "values" is + # statically known. + tensor = random_ops.random_uniform( + # Rank is statically known to be 5, but the dimension lengths are not + # known. + random_ops.random_uniform( + shape=(5,), minval=0, maxval=10, dtype=dtypes.int32)) + sort_ops.sort(tensor, axis=1) + transposition = ( + ops.get_default_graph().get_tensor_by_name('sort/transposition:0')) + self.assertFalse(tensor_util.constant_value(transposition) is None) + self.assertAllEqual( + # Swaps "1" and "4" to put "1" at the end. + tensor_util.constant_value(transposition), + [0, 4, 2, 3, 1]) + + def testArgsort_1d(self): + arr = np.random.random(42) + with self.test_session(): + self.assertAllEqual( + np.sort(arr), + array_ops.gather(arr, sort_ops.argsort(arr)).eval()) + + def testArgsort(self): + arr = np.random.random((5, 6, 7, 8)) + for axis in range(4): + with self.test_session(): + self.assertAllEqual( + np.argsort(arr, axis=axis), + sort_ops.argsort(arr, axis=axis).eval()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index ce37672895b37275770d2f5410f662e9acf1de9d..0eb6889db1fae1c74aeb4392441b308392b091a5 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -157,15 +157,3 @@ cuda_py_test( "requires_cudnn6", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py index a97adf622e6e576f8b5ce2babe004cb3a46d80a5..983b6dc8e5a1512ba81ecbc8d5ca5adaea09afe4 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py @@ -65,7 +65,7 @@ def fused_conv2d_bias_activation(conv_input, side_input_scale: A scalar `float32` that will be multiplied by side_input. This is optional and defaults to 0. side_input: A `Tensor` of the format specified by `data_format`. - This is useful for imlementing ResNet blocks. + This is useful for implementing ResNet blocks. activation_mode: (optional) currently must be the default "Relu". Note that in qint8 mode, it also clips to 127, so acts like ReluX. data_format: Specifies the data format. diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index bb155aa2496cbafd9f0630d3dffb2ba69395186c..3d0ed899322c26bf4ae428930899d7a5885e9f21 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -566,7 +566,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, return Test -def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type): +def CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type): """Calculates the size of an output dimension of a strided convolution. Given the sizes of the corresponding dimension of the input and filter shapes, @@ -827,10 +827,10 @@ class FusedConvInt8Tests(test.TestCase): maxval=1.0, dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) - output_height = CalculateCovolvedOutputDim(input_height, filter_height, - vertical_stride, padding_type) - output_width = CalculateCovolvedOutputDim(input_width, filter_width, - horizontal_stride, padding_type) + output_height = CalculateConvolvedOutputDim(input_height, filter_height, + vertical_stride, padding_type) + output_width = CalculateConvolvedOutputDim(input_width, filter_width, + horizontal_stride, padding_type) print("output_height=", output_height, ", output_width=", output_width) side_input, _, _ = gen_array_ops.quantize_v2( diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 5db34f0f8db93620b8b4a6b71f63b66ac718ee30..b305f37791d71f5a6edeada2bb710a2e5f23087d 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -55,6 +55,7 @@ py_test( name = "train_test", srcs = ["python/train_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":features", ":namedtuples", @@ -353,6 +354,7 @@ py_test( name = "classifier_metrics_test", srcs = ["python/eval/python/classifier_metrics_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":classifier_metrics", "//tensorflow/core:protos_all_py", @@ -362,6 +364,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:variables", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -543,15 +546,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 082c42eba180917e732bb7890129dfa94bf00fec..e3fc6bf0f034051fc33ff5966e2f4ea85aa538db 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -88,8 +88,8 @@ class GANEstimator(estimator.Estimator): discriminator_fn=discriminator_fn, generator_loss_fn=tfgan.losses.wasserstein_generator_loss, discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5), - discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5)) + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) # Train estimator. gan_estimator.train(train_input_fn, steps) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index fdfabd07c13f689d075ecbb8786d725fa8a62d01..47e51415fd9e7daa360ca06a11078f6edcf63b5b 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -44,11 +44,11 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import resource_loader - __all__ = [ 'get_graph_def_from_disk', 'get_graph_def_from_resource', @@ -62,10 +62,11 @@ __all__ = [ 'frechet_inception_distance', 'frechet_classifier_distance', 'frechet_classifier_distance_from_activations', + 'mean_only_frechet_classifier_distance_from_activations', + 'diagonal_only_frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] - INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' INCEPTION_INPUT = 'Mul:0' @@ -77,8 +78,7 @@ INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): images = ops.convert_to_tensor(images) images.shape.with_rank(4) - images.shape.assert_is_compatible_with( - [None, image_size, image_size, None]) + images.shape.assert_is_compatible_with([None, image_size, image_size, None]) return images @@ -109,9 +109,10 @@ def _symmetric_matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -def preprocess_image( - images, height=INCEPTION_DEFAULT_IMAGE_SIZE, - width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): +def preprocess_image(images, + height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, + scope=None): """Prepare a batch of images for evaluation. This is the preprocessing portion of the graph from @@ -272,8 +273,11 @@ def run_inception(images, return activations -def run_image_classifier(tensor, graph_def, input_tensor, - output_tensor, scope='RunClassifier'): +def run_image_classifier(tensor, + graph_def, + input_tensor, + output_tensor, + scope='RunClassifier'): """Runs a network from a frozen graph. Args: @@ -317,7 +321,7 @@ def classifier_score(images, classifier_fn, num_batches=1): NOTE: This function consumes images, computes their logits, and then computes the classifier score. If you would like to precompute many logits for - large batches, use clasifier_score_from_logits(), which this method also + large batches, use classifier_score_from_logits(), which this method also uses. Args: @@ -433,8 +437,8 @@ def trace_sqrt_product(sigma, sigma_v): sqrt_sigma = _symmetric_matrix_square_root(sigma) # This is sqrt(A sigma_v A) above - sqrt_a_sigmav_a = math_ops.matmul( - sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma)) + sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma, + math_ops.matmul(sigma_v, sqrt_sigma)) return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) @@ -450,9 +454,9 @@ def frechet_classifier_distance(real_images, This technique is described in detail in https://arxiv.org/abs/1706.08500. Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates + C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the @@ -463,7 +467,7 @@ def frechet_classifier_distance(real_images, Frechet distance is biased. It is more biased for small sample sizes. (e.g. even if the two distributions are the same, for a small sample size, the expected Frechet distance is large). It is important to use the same - sample size to compute frechet classifier distance when comparing two + sample size to compute Frechet classifier distance when comparing two generative models. NOTE: This function consumes images, computes their activations, and then @@ -511,10 +515,142 @@ def frechet_classifier_distance(real_images, return frechet_classifier_distance_from_activations(real_a, gen_a) -def frechet_classifier_distance_from_activations( +def mean_only_frechet_classifier_distance_from_activations( real_activations, generated_activations): """Classifier distance for evaluating a generative model from activations. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + In this variant, we only compute the difference between the means of the + fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet + still retains much of the same information as FID. + + Args: + real_activations: 2D array of activations of real images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + generated_activations: 2D array of activations of generated images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + + Returns: + The mean-only Frechet Inception distance. A floating-point scalar of the + same type as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute means of activations. + m = math_ops.reduce_mean(real_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + mofid = mean + if activations_dtype != dtypes.float64: + mofid = math_ops.cast(mofid, activations_dtype) + + return mofid + + +def diagonal_only_frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. In this variant, we compute diagonal-only covariance matrices. + As a result, instead of computing an expensive matrix square root, we can do + something much simpler, and has O(n) vs O(n^2) space complexity. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The diagonal-only Frechet Inception distance. A floating-point scalar of + the same type as the output of the activations. + + Raises: + ValueError: If the shape of the variance and mean vectors are not equal. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute mean and covariance matrices of activations. + m, var = nn_impl.moments(real_activations, axes=[0]) + m_w, var_w = nn_impl.moments(generated_activations, axes=[0]) + + actual_shape = var.get_shape() + expected_shape = m.get_shape() + + if actual_shape != expected_shape: + raise ValueError('shape: {} must match expected shape: {}'.format( + actual_shape, expected_shape)) + + # Compute the two components of FID. + + # First the covariance component. + # Here, note that trace(A + B) = trace(A) + trace(B) + trace = math_ops.reduce_sum( + (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) + + # Next the distance between means. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. + dofid = trace + mean + if activations_dtype != dtypes.float64: + dofid = math_ops.cast(dofid, activations_dtype) + + return dofid + + +def frechet_classifier_distance_from_activations(real_activations, + generated_activations): + """Classifier distance for evaluating a generative model. + This methods computes the Frechet classifier distance from activations of real images and generated images. This can be used independently of the frechet_classifier_distance() method, especially in the case of using large @@ -523,15 +659,22 @@ def frechet_classifier_distance_from_activations( This technique is described in detail in https://arxiv.org/abs/1706.08500. Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates + C and C_w, this function calculates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the Inception score, this is a true distance and utilizes information about real world images. + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + Args: real_activations: 2D Tensor containing activations of real data. Shape is [batch_size, activation_size]. @@ -553,36 +696,38 @@ def frechet_classifier_distance_from_activations( # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) - m_v = math_ops.reduce_mean(generated_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( - real_centered, real_centered, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / ( + num_examples - 1) - gen_centered = generated_activations - m_v - sigma_v = math_ops.matmul( - gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_w + sigma_w = math_ops.matmul( + gen_centered, gen_centered, transpose_a=True) / ( + num_examples - 1) - # Find the Tr(sqrt(sigma sigma_v)) component of FID - sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) + # Find the Tr(sqrt(sigma sigma_w)) component of FID + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) # Compute the two components of FID. # First the covariance component. # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component + trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) return fid - frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 61dc8646ddc10605561ae6b19e90f4739c346608..4fb8d58bc9125664d42260de72b83b2362eff9ba 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -22,6 +22,7 @@ import os import tarfile import tempfile +from absl.testing import parameterized import numpy as np from scipy import linalg as scp_linalg @@ -50,6 +51,26 @@ def _expected_inception_score(logits): return np.exp(np.mean(per_example_logincscore)) +def _expected_mean_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + mean = np.square(m - m_v).sum() + mofid = mean + return mofid + + +def _expected_diagonal_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + var = np.var(real_imgs, axis=0) + var_v = np.var(gen_imgs, axis=0) + sqcc = np.sqrt(var * var_v) + mean = (np.square(m - m_v)).sum() + trace = (var + var_v - 2 * sqcc).sum() + dofid = mean + trace + return dofid + + def _expected_fid(real_imgs, gen_imgs): m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) @@ -162,13 +183,20 @@ def _run_with_mock(function, *args, **kwargs): return function(*args, **kwargs) -class ClassifierMetricsTest(test.TestCase): +class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): - def test_run_inception_graph(self): + @parameterized.named_parameters( + ('GraphDef', False), + ('DefaultGraphDefFn', True)) + def test_run_inception_graph(self, use_default_graph_def): """Test `run_inception` graph construction.""" batch_size = 7 img = array_ops.ones([batch_size, 299, 299, 3]) - logits = _run_with_mock(classifier_metrics.run_inception, img) + + if use_default_graph_def: + logits = _run_with_mock(classifier_metrics.run_inception, img) + else: + logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) self.assertTrue(isinstance(logits, ops.Tensor)) logits.shape.assert_is_compatible_with([batch_size, 1001]) @@ -176,14 +204,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) - def test_run_inception_graph_pool_output(self): + @parameterized.named_parameters( + ('GraphDef', False), + ('DefaultGraphDefFn', True)) + def test_run_inception_graph_pool_output(self, use_default_graph_def): """Test `run_inception` graph construction with pool output.""" batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) - pool = _run_with_mock( - classifier_metrics.run_inception, - img, - output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) + + if use_default_graph_def: + pool = _run_with_mock( + classifier_metrics.run_inception, + img, + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) + else: + pool = classifier_metrics.run_inception( + img, _get_dummy_graphdef(), + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -285,6 +322,46 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(_expected_inception_score(logits), incscore_np) + def test_mean_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_mofid = sess.run(mofid_op) + + expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_mofid, actual_mofid, 0.0001) + + def test_diagonal_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_dofid = sess.run(dofid_op) + + expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_dofid, actual_dofid, 0.0001) + def test_frechet_classifier_distance_value(self): """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 9bebcacbe46d85fc4226c4275b71b3ecbde57a97..4b1105f6bd4f21a0da02338b0fc9db87a41b145f 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): proj = random_ops.random_normal( [array_ops.shape(a)[1], random_projection_dim]) proj *= math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True)) # Project both distributions and sort them. proj_a = math_ops.matmul(a, proj) proj_b = math_ops.matmul(b, proj) @@ -212,7 +212,7 @@ def sliced_wasserstein_distance(real_images, Args: real_images: (tensor) Real images (batch, height, width, channels). fake_images: (tensor) Fake images (batch, height, width, channels). - resolution_min: (int) Minimum resolution for the Laplacion pyramid. + resolution_min: (int) Minimum resolution for the Laplacian pyramid. patches_per_image: (int) Number of patches to extract per image per Laplacian level. patch_size: (int) Width of a square patch. @@ -221,7 +221,7 @@ def sliced_wasserstein_distance(real_images, use_svd: experimental method to compute a more accurate distance. Returns: List of tuples (distance_real, distance_fake) for each level of the - Laplacian pyramid from the highest resoluion to the lowest. + Laplacian pyramid from the highest resolution to the lowest. distance_real is the Wasserstein distance between real images distance_fake is the Wasserstein distance between real and fake images. Raises: diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 0d1afad72da8a8e087239868e25ddebe23490d1e..508f487722fba89cc8391a340f73673a526e86c4 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -31,6 +31,7 @@ __all__ = [ 'add_image_comparison_summaries', 'add_gan_model_summaries', 'add_regularization_loss_summaries', + 'add_cyclegan_image_summaries', ] @@ -51,14 +52,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): ValueError: If real and generated data aren't images. """ if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_summaries'): - add_gan_model_image_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_summaries'): - add_gan_model_image_summaries(gan_model.model_y2x, **saved_params) - return - + raise ValueError( + '`add_gan_model_image_summaries` does not take CycleGANModels. Please ' + 'use `add_cyclegan_image_summaries` instead.') _assert_is_image(gan_model.real_data) _assert_is_image(gan_model.generated_data) @@ -89,6 +85,49 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): add_gan_model_summaries(gan_model) +def add_cyclegan_image_summaries(cyclegan_model): + """Adds image summaries for CycleGAN. + + There are two summaries, one for each generator. The first image is the + generator input, the second is the generator output, and the third is G(F(x)). + + Args: + cyclegan_model: A CycleGANModel tuple. + + Raises: + ValueError: If `cyclegan_model` isn't a CycleGANModel. + ValueError: If generated data, generator inputs, and reconstructions aren't + images. + ValueError: If the generator input, generated data, and reconstructions + aren't all the same size. + """ + if not isinstance(cyclegan_model, namedtuples.CycleGANModel): + raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was ' + '%s' % type(cyclegan_model)) + + _assert_is_image(cyclegan_model.model_x2y.generator_inputs) + _assert_is_image(cyclegan_model.model_x2y.generated_data) + _assert_is_image(cyclegan_model.reconstructed_x) + _assert_is_image(cyclegan_model.model_y2x.generator_inputs) + _assert_is_image(cyclegan_model.model_y2x.generated_data) + _assert_is_image(cyclegan_model.reconstructed_y) + + def _add_comparison_summary(gan_model, reconstructions): + image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) + + array_ops.unstack(gan_model.generated_data[:1]) + + array_ops.unstack(reconstructions[:1])) + summary.image( + 'image_comparison', eval_utils.image_reshaper( + image_list, num_cols=len(image_list)), max_outputs=1) + + with ops.name_scope('x2y_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_x2y, cyclegan_model.reconstructed_x) + with ops.name_scope('y2x_image_comparison_summaries'): + _add_comparison_summary( + cyclegan_model.model_y2x, cyclegan_model.reconstructed_y) + + def add_image_comparison_summaries(gan_model, num_comparisons=2, display_diffs=False): """Adds image summaries to compare triplets of images. @@ -109,15 +148,6 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, ValueError: If the generator input, real, and generated data aren't all the same size. """ - if isinstance(gan_model, namedtuples.CycleGANModel): - saved_params = locals() - saved_params.pop('gan_model', None) - with ops.name_scope('cyclegan_x2y_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_x2y, **saved_params) - with ops.name_scope('cyclegan_y2x_image_comparison_summaries'): - add_image_comparison_summaries(gan_model.model_y2x, **saved_params) - return - _assert_is_image(gan_model.generator_inputs) _assert_is_image(gan_model.generated_data) _assert_is_image(gan_model.real_data) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 7956db43348c0cc0f3d372e92a2e343f5aa62013..33d51bfc218ab93fb52439b1eefed98a4568c4a1 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -65,15 +65,14 @@ def get_cyclegan_model(): return namedtuples.CycleGANModel( model_x2y=model_x2y, model_y2x=model_y2x, - reconstructed_x=array_ops.zeros([3, 30, 35, 6]), - reconstructed_y=array_ops.zeros([3, 30, 35, 6])) + reconstructed_x=array_ops.zeros([4, 32, 32, 3]), + reconstructed_y=array_ops.zeros([4, 32, 32, 3])) class SummariesTest(test.TestCase): - def _test_add_gan_model_image_summaries_impl(self, get_model_fn, - expected_num_summary_ops, - model_summaries): + def _test_add_gan_model_image_summaries_impl( + self, get_model_fn, expected_num_summary_ops, model_summaries): summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, model_summaries=model_summaries) @@ -89,9 +88,9 @@ class SummariesTest(test.TestCase): def test_add_gan_model_image_summaries_no_model(self): self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) - def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, - True) + def test_cyclegan_image_summaries_dont_work(self): + with self.assertRaises(ValueError): + summaries.add_gan_model_image_summaries(get_cyclegan_model()) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): @@ -138,7 +137,11 @@ class SummariesTest(test.TestCase): self._test_add_image_comparison_summaries_impl(get_gan_model, 1) def test_add_image_comparison_summaries_for_cyclegan(self): - self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2) + summaries.add_cyclegan_image_summaries(get_cyclegan_model()) + + self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + with self.test_session(use_gpu=True): + summary.merge_all().eval() if __name__ == '__main__': diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py index cd31c62667fc048b1003d334377405b284f32af5..e2594faf85bcf91cbe09f266e4d4211d20bdee17 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Miscellanous utilities for TFGAN code and examples. +"""Miscellaneous utilities for TFGAN code and examples. Includes: 1) Conditioning the value of a Tensor, based on techniques from diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 4cfae0de4451880cf8229903b0eb74b1c6e2e04d..9e4ec59e7098443efc53506a4ba159e84b5c1618 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -17,7 +17,7 @@ We use this to keep a history of values created by a generator, such that a discriminator can randomly be trained on some older samples, not just the current one. This can help to not let the discriminator get too far ahead of the -generator and also to keep the system from oscilating, if the discriminator +generator and also to keep the system from oscillating, if the discriminator forgets too fast what past samples from the generator looked like. See the following papers for more details. @@ -97,7 +97,7 @@ def tensor_pool(input_values, dtypes=[v.dtype for v in input_values], shapes=None) - # In pseudeo code this code does the following: + # In pseudo code this code does the following: # if not pool_full: # enqueue(input_values) # return input_values diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f8b372546b60ec8fa5fd1d72b57adaf67596c059..650eab97a3952e9aec2b489fffcc83c3bc49f2dd 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -64,11 +64,11 @@ def _statistics(x, axes): y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x # Compute true mean while keeping the dims for proper broadcasting. - shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True)) + shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True) + shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True) + mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) mean = array_ops.squeeze(mean, axes) mean_squared = array_ops.squeeze(mean_squared, axes) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py index 845f89827b6e60eda41a55a80671f43460247b05..2fe06a287284ff994326d5a977a2e4d4634268ae 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -148,7 +148,7 @@ class VirtualBatchnormTest(test.TestCase): self.assertAllClose(bn_np[i, ...], vb_np) def test_minibatch_independent(self): - """Test that virtual batch normalized exampels are independent. + """Test that virtual batch normalized examples are independent. Unlike batch normalization, virtual batch normalization has the property that the virtual batch normalized value of an example is independent of the diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 39588b7219ebac1cc4855532be3fcc38e6381134..1ba3a641671c7f2a411a0c5f99228ca16eee1080 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -306,6 +306,7 @@ def wasserstein_gradient_penalty( discriminator_scope, epsilon=1e-10, target=1.0, + one_sided=False, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -327,6 +328,8 @@ def wasserstein_gradient_penalty( computing the gradient norm. target: Optional Python number or `Tensor` indicating the target value of gradient norm. Defaults to 1.0. + one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894 + is used. Defaults to `False`. weights: Optional `Tensor` whose rank is either 0, or the same rank as `real_data` and `generated_data`, and must be broadcastable to them (i.e., all dimensions must be either `1`, or the same as the @@ -377,10 +380,13 @@ def wasserstein_gradient_penalty( # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes / target - 1.0) + penalties = slopes / target - 1.0 + if one_sided: + penalties = math_ops.maximum(0., penalties) + penalties_squared = math_ops.square(penalties) penalty = losses.compute_weighted_loss( - penalties, weights, scope=scope, loss_collection=loss_collection, - reduction=reduction) + penalties_squared, weights, scope=scope, + loss_collection=loss_collection, reduction=reduction) if add_summaries: summary.scalar('gradient_penalty_loss', penalty) @@ -665,7 +671,7 @@ def least_squares_discriminator_loss( loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): - """Least squares generator loss. + """Least squares discriminator loss. This loss comes from `Least Squares Generative Adversarial Networks` (https://arxiv.org/abs/1611.04076). diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index dbaa624ae9d6a5a5949db692e52c0c1deb18b8df..2889e937436d2faa66b5693c19046e122cbaf652 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -481,6 +481,28 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): }) self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_using_one_sided_mode(self): + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope'], + one_sided=True) + self.assertEqual(generated_data.dtype, loss.dtype) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run(loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_with_gradient_norm_target(self): """Test loss value with non default gradient norm target.""" generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 776eb11ecb1624544d24611d8fe6ca19768b8313..6fa43059f3125daea080f780210223363d0a89f9 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -461,6 +461,7 @@ def gan_loss( gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, gradient_penalty_target=1.0, + gradient_penalty_one_sided=False, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, @@ -485,6 +486,8 @@ def gan_loss( gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python number or `Tensor` indicating the target value of gradient norm. See the CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. + gradient_penalty_one_sided: If `True`, penalty proposed in + https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more @@ -546,6 +549,7 @@ def gan_loss( model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, + one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): @@ -706,7 +710,10 @@ def gan_train_ops( be used to train a generator/discriminator pair. """ if isinstance(model, namedtuples.CycleGANModel): - saved_params = locals() + # Get and store all arguments other than model and loss from locals. + # Contents of locals should not be modified, may not affect values. So make + # a copy. https://docs.python.org/2/library/functions.html#locals. + saved_params = dict(locals()) saved_params.pop('model', None) saved_params.pop('loss', None) kwargs = saved_params.pop('kwargs', {}) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index f9bdaa74c948ecee11d5cfd89f06087924f8dace..3ebbe55d059e5e72607bc4efdbf95a6c96d99f11 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -359,10 +359,12 @@ class GANLossTest(test.TestCase): self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) # Test gradient penalty option. - def _test_grad_penalty_helper(self, create_gan_model_fn): + def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False): model = create_gan_model_fn() loss = train.gan_loss(model) - loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0) + loss_gp = train.gan_loss(model, + gradient_penalty_weight=1.0, + gradient_penalty_one_sided=one_sided) self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss)) # Check values. @@ -394,6 +396,25 @@ class GANLossTest(test.TestCase): def test_grad_penalty_callable_acgan(self): self._test_grad_penalty_helper(create_callable_acgan_model) + def test_grad_penalty_one_sided_gan(self): + self._test_grad_penalty_helper(create_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_gan(self): + self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_infogan(self): + self._test_grad_penalty_helper(create_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_infogan(self): + self._test_grad_penalty_helper( + create_callable_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_acgan(self): + self._test_grad_penalty_helper(create_acgan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_acgan(self): + self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True) + # Test mutual information penalty option. def _test_mutual_info_penalty_helper(self, create_gan_model_fn): train.gan_loss(create_gan_model_fn(), diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 707ae25d485c64f15694ee0e357f32b619d3cd33..e534fdc17749974ebe713c2730682bea6d7a85e4 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -9,18 +9,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - filegroup( name = "c_srcs", data = glob([ diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index 967ad2fc090906e93f22c777816eede37f9a1b04..1711100e3a857dba0d15c5b4f6c96cddc568e800 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -39,18 +39,6 @@ py_library( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_library( name = "match", srcs = ["tests/match.py"], diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 7ffdbb7139281734917fdb715601b317eb58b82f..95c02a64d47c26e731ef2628fb551529e9bc3f4d 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -471,9 +471,10 @@ def remove_control_inputs(op, cops): if cop not in op.control_inputs: raise ValueError("{} is not a control_input of {}".format(op.name, cop.name)) + control_inputs = [cop for cop in op.control_inputs if cop not in cops] # pylint: disable=protected-access - op._control_inputs = [cop for cop in op._control_inputs if cop not in cops] - op._recompute_node_def() + op._remove_all_control_inputs() + op._add_control_inputs(control_inputs) # pylint: enable=protected-access @@ -496,9 +497,6 @@ def add_control_inputs(op, cops): if cop in op.control_inputs: raise ValueError("{} is already a control_input of {}".format(cop.name, op.name)) - # pylint: disable=protected-access - op._control_inputs += cops - op._recompute_node_def() - # pylint: enable=protected-access + op._add_control_inputs(cops) # pylint: disable=protected-access remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index 3ea6ff4d6163b107ca0daaf3b9ad1daf0ccc1f6f..d700e6e1a7523622f845acbbc353eb0f438c9bc2 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -383,6 +383,7 @@ def get_within_boundary_ops(ops, def get_forward_walk_ops(seed_ops, inclusive=True, within_ops=None, + within_ops_fn=None, stop_at_ts=(), control_outputs=None): """Do a forward graph walk and return all the visited ops. @@ -395,6 +396,9 @@ def get_forward_walk_ops(seed_ops, within_ops: an iterable of `tf.Operation` within which the search is restricted. If `within_ops` is `None`, the search is performed within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. stop_at_ts: an iterable of tensors at which the graph walk stops. control_outputs: a `util.ControlOutputs` instance or None. If not `None`, it will be used while walking the graph forward. @@ -423,7 +427,8 @@ def get_forward_walk_ops(seed_ops, seed_ops &= within_ops def is_within(op): - return within_ops is None or op in within_ops + return (within_ops is None or op in within_ops) and ( + within_ops_fn is None or within_ops_fn(op)) result = list(seed_ops) wave = set(seed_ops) @@ -450,6 +455,7 @@ def get_forward_walk_ops(seed_ops, def get_backward_walk_ops(seed_ops, inclusive=True, within_ops=None, + within_ops_fn=None, stop_at_ts=(), control_inputs=False): """Do a backward graph walk and return all the visited ops. @@ -462,6 +468,9 @@ def get_backward_walk_ops(seed_ops, within_ops: an iterable of `tf.Operation` within which the search is restricted. If `within_ops` is `None`, the search is performed within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. stop_at_ts: an iterable of tensors at which the graph walk stops. control_inputs: if True, control inputs will be used while moving backward. Returns: @@ -488,7 +497,8 @@ def get_backward_walk_ops(seed_ops, seed_ops &= within_ops def is_within(op): - return within_ops is None or op in within_ops + return (within_ops is None or op in within_ops) and ( + within_ops_fn is None or within_ops_fn(op)) result = list(seed_ops) wave = set(seed_ops) @@ -516,6 +526,7 @@ def get_walks_intersection_ops(forward_seed_ops, forward_inclusive=True, backward_inclusive=True, within_ops=None, + within_ops_fn=None, control_inputs=False, control_outputs=None, control_ios=None): @@ -535,6 +546,9 @@ def get_walks_intersection_ops(forward_seed_ops, within_ops: an iterable of tf.Operation within which the search is restricted. If within_ops is None, the search is performed within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. control_inputs: A boolean indicating whether control inputs are enabled. control_outputs: An instance of util.ControlOutputs or None. If not None, control outputs are enabled. @@ -555,11 +569,13 @@ def get_walks_intersection_ops(forward_seed_ops, forward_seed_ops, inclusive=forward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_outputs=control_outputs) backward_ops = get_backward_walk_ops( backward_seed_ops, inclusive=backward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_inputs=control_inputs) return [op for op in forward_ops if op in backward_ops] @@ -569,6 +585,7 @@ def get_walks_union_ops(forward_seed_ops, forward_inclusive=True, backward_inclusive=True, within_ops=None, + within_ops_fn=None, control_inputs=False, control_outputs=None, control_ios=None): @@ -587,6 +604,9 @@ def get_walks_union_ops(forward_seed_ops, resulting set. within_ops: restrict the search within those operations. If within_ops is None, the search is done within the whole graph. + within_ops_fn: if provided, a function on ops that should return True iff + the op is within the graph traversal. This can be used along within_ops, + in which case an op is within if it is also in within_ops. control_inputs: A boolean indicating whether control inputs are enabled. control_outputs: An instance of util.ControlOutputs or None. If not None, control outputs are enabled. @@ -607,11 +627,13 @@ def get_walks_union_ops(forward_seed_ops, forward_seed_ops, inclusive=forward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_outputs=control_outputs) backward_ops = get_backward_walk_ops( backward_seed_ops, inclusive=backward_inclusive, within_ops=within_ops, + within_ops_fn=within_ops_fn, control_inputs=control_inputs) return util.concatenate_unique(forward_ops, backward_ops) diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py index 82f999637d0c1866a5a329974f021fe2e30fd33f..d12c6d3cbd11dde2b609a59154297a8907b0cadc 100644 --- a/tensorflow/contrib/graph_editor/tests/select_test.py +++ b/tensorflow/contrib/graph_editor/tests/select_test.py @@ -77,12 +77,10 @@ class SelectTest(test.TestCase): """Test for ge.get_ops_ios.""" control_outputs = ge.util.ControlOutputs(self.graph) self.assertEqual( - len(ge.get_ops_ios( - self.h.op, control_ios=control_outputs)), 3) + len(ge.get_ops_ios(self.h.op, control_ios=control_outputs)), 3) self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2) self.assertEqual( - len(ge.get_ops_ios( - self.c.op, control_ios=control_outputs)), 6) + len(ge.get_ops_ios(self.c.op, control_ios=control_outputs)), 6) self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5) def test_compute_boundary_ts_0(self): @@ -135,16 +133,49 @@ class SelectTest(test.TestCase): ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op]) self.assertEqual(len(ops), 2) + ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op]) + self.assertEqual(len(ops), 3) + self.assertTrue(self.a.op in ops) + self.assertTrue(self.c.op in ops) + self.assertTrue(self.f.op in ops) + + within_ops = [self.a.op, self.f.op] + ops = ge.get_walks_intersection_ops( + [self.a.op], [self.f.op], within_ops=within_ops) + self.assertEqual(len(ops), 0) + + within_ops_fn = lambda op: op in [self.a.op, self.f.op] + ops = ge.get_walks_intersection_ops( + [self.a.op], [self.f.op], within_ops_fn=within_ops_fn) + self.assertEqual(len(ops), 0) + def test_get_walks_union(self): """Test for ge.get_walks_union_ops.""" ops = ge.get_walks_union_ops([self.f.op], [self.g.op]) self.assertEqual(len(ops), 6) + ops = ge.get_walks_union_ops([self.a.op], [self.f.op]) + self.assertEqual(len(ops), 8) + + within_ops = [self.a.op, self.c.op, self.d.op, self.f.op] + ops = ge.get_walks_union_ops([self.a.op], [self.f.op], + within_ops=within_ops) + self.assertEqual(len(ops), 4) + self.assertTrue(self.b.op not in ops) + + within_ops_fn = lambda op: op in [self.a.op, self.c.op, self.f.op] + ops = ge.get_walks_union_ops([self.a.op], [self.f.op], + within_ops_fn=within_ops_fn) + self.assertEqual(len(ops), 3) + self.assertTrue(self.b.op not in ops) + self.assertTrue(self.d.op not in ops) + def test_select_ops(self): parameters = ( (("^foo/",), 7), (("^foo/bar/",), 4), - (("^foo/bar/", "a"), 5),) + (("^foo/bar/", "a"), 5), + ) for param, length in parameters: ops = ge.select_ops(*param, graph=self.graph) self.assertEqual(len(ops), length) @@ -152,7 +183,8 @@ class SelectTest(test.TestCase): def test_select_ts(self): parameters = ( (".*:0", 8), - (r".*/bar/\w+:0", 4),) + (r".*/bar/\w+:0", 4), + ) for regex, length in parameters: ts = ge.select_ts(regex, graph=self.graph) self.assertEqual(len(ts), length) @@ -160,12 +192,121 @@ class SelectTest(test.TestCase): def test_select_ops_and_ts(self): parameters = ( (("^foo/.*",), 7, 0), - (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),) + (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4), + ) for param, l0, l1 in parameters: ops, ts = ge.select_ops_and_ts(*param, graph=self.graph) self.assertEqual(len(ops), l0) self.assertEqual(len(ts), l1) + def test_forward_walk_ops(self): + seed_ops = [self.a.op, self.d.op] + # Include all ops except for self.g.op + within_ops = [ + x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] + ] + # For the fn, exclude self.e.op. + within_ops_fn = lambda op: op not in (self.e.op,) + stop_at_ts = (self.f,) + + with self.graph.as_default(): + # No b.op since it's an independent source node. + # No g.op from within_ops. + # No e.op from within_ops_fn. + # No h.op from stop_at_ts and within_ops. + ops = ge.select.get_forward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual( + set(ops), set([self.a.op, self.c.op, self.d.op, self.f.op])) + + # Also no a.op and d.op when inclusive=False + ops = ge.select.get_forward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.c.op, self.f.op])) + + # Not using within_ops_fn adds e.op. + ops = ge.select.get_forward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.c.op, self.e.op, self.f.op])) + + # Not using stop_at_ts adds back h.op. + ops = ge.select.get_forward_walk_ops( + seed_ops, inclusive=False, within_ops=within_ops) + self.assertEqual( + set(ops), set([self.c.op, self.e.op, self.f.op, self.h.op])) + + # Starting just form a (the tensor, not op) omits a, b, d. + ops = ge.select.get_forward_walk_ops([self.a], inclusive=True) + self.assertEqual( + set(ops), set([self.c.op, self.e.op, self.f.op, self.g.op, + self.h.op])) + + def test_backward_walk_ops(self): + seed_ops = [self.h.op] + # Include all ops except for self.g.op + within_ops = [ + x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h] + ] + # For the fn, exclude self.c.op. + within_ops_fn = lambda op: op not in (self.c.op,) + stop_at_ts = (self.f,) + + with self.graph.as_default(): + # Backward walk only includes h since we stop at f and g is not within. + ops = ge.select.get_backward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set([self.h.op])) + + # If we do inclusive=False, the result is empty. + ops = ge.select.get_backward_walk_ops( + seed_ops, + inclusive=False, + within_ops=within_ops, + within_ops_fn=within_ops_fn, + stop_at_ts=stop_at_ts) + self.assertEqual(set(ops), set()) + + # Removing stop_at_fs adds f.op, d.op. + ops = ge.select.get_backward_walk_ops( + seed_ops, + inclusive=True, + within_ops=within_ops, + within_ops_fn=within_ops_fn) + self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op])) + + # Not using within_ops_fn adds back ops for a, b, c. + ops = ge.select.get_backward_walk_ops( + seed_ops, inclusive=True, within_ops=within_ops) + self.assertEqual( + set(ops), + set([ + self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op + ])) + + # Vanially backward search via self.h.op includes everything excpet e.op. + ops = ge.select.get_backward_walk_ops(seed_ops, inclusive=True) + self.assertEqual( + set(ops), + set([ + self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op, + self.h.op + ])) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ca00394388f67e2ed9508684a47b23c3ee9e79e8..97f38c923f4a19cedf3e16203ca1e66b7e5e45d2 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -18,11 +18,14 @@ from __future__ import division from __future__ import print_function import collections +import functools import numpy as np from tensorflow.contrib import graph_editor as ge from tensorflow.contrib.graph_editor.tests import match +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -41,6 +44,7 @@ class TransformTest(test.TestCase): self.graph = ops.Graph() with self.graph.as_default(): c0 = constant_op.constant(1.0, shape=[10], name="Const") + c0.op._set_attr("_foo", attr_value_pb2.AttrValue(s=b"foo")) c1 = constant_op.constant(1.0, shape=[10], name="Const") c2 = constant_op.constant(1.0, shape=[10], name="Const") i = constant_op.constant(1.0, shape=[10], name="Input") @@ -84,9 +88,9 @@ class TransformTest(test.TestCase): def test_transform(self): transformer = ge.Transformer() - def my_transform_op_handler(info, op): + def my_transform_op_handler(info, op, new_inputs): add_noise = op.name.startswith("Add") - op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + op_, op_outputs_ = ge.transform.copy_op_handler(info, op, new_inputs) if not add_noise: return op_, op_outputs_ # add some noise to op @@ -111,6 +115,32 @@ class TransformTest(test.TestCase): top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) + def test_transform_nodedef_fn(self): + transformer = ge.Transformer() + + def nodedef_fn(node_def): + if "_foo" in node_def.attr: + del node_def.attr["_foo"] + node_def.attr["_bar"].s = b"bar" + return node_def + + my_copy_op_handler = functools.partial( + ge.transform.copy_op_handler, nodedef_fn=nodedef_fn) + transformer.transform_op_handler = my_copy_op_handler + + graph = ops.Graph() + transformer(self.graph, graph, "", "") + + c0_before = self.graph.get_operation_by_name("Const") + c0_after = graph.get_operation_by_name("Const") + self.assertEquals(c0_before.get_attr("_foo"), b"foo") + with self.assertRaises(ValueError): + c0_after.get_attr("_foo") + + all_ops = graph.get_operations() + for op in all_ops: + self.assertEquals(op.get_attr("_bar"), b"bar") + def test_copy_with_input_replacements(self): with self.graph.as_default(): ten = constant_op.constant(10.0, shape=[10], name="Input") @@ -201,15 +231,56 @@ class TransformTest(test.TestCase): get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. - self.assertEquals(original_mul1_grad._original_op.name, u"mul1") - self.assertEquals(result_mul1_grad._original_op.name, u"res/mul1") - self.assertNotEquals(res.name, g.name) + self.assertEqual(original_mul1_grad._original_op.name, u"mul1") + self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1") + self.assertNotEqual(res.name, g.name) with session.Session() as sess: sess.run(variables.global_variables_initializer()) g_val, res_val = sess.run([g, res]) self.assertNear(g_val, 0.0, ERROR_TOLERANCE) self.assertNear(res_val, 0.0, ERROR_TOLERANCE) + def test_graph_while_loop(self): + graph = ops.Graph() + with graph.as_default(): + max_index = array_ops.placeholder(dtype=dtypes.int32, shape=tuple()) + index_start = constant_op.constant(1) + sum_start = constant_op.constant(0) + _, result = control_flow_ops.while_loop( + cond=lambda i, unused_s: i <= max_index, + body=lambda i, s: (i + 1, s + i), + loop_vars=[index_start, sum_start]) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_max_index = copy_info.transformed(max_index) + with copied_graph.as_default(): + with session.Session() as sess: + n = 10 + sum_val = sess.run(copied_result, feed_dict={copied_max_index: n}) + self.assertEqual(sum_val, 55) + + def test_graph_cond(self): + graph = ops.Graph() + with graph.as_default(): + choice = array_ops.placeholder(shape=(), dtype=dtypes.bool) + result = control_flow_ops.cond( + choice, + lambda: constant_op.constant(1), + lambda: constant_op.constant(2)) + copied_graph = ops.Graph() + _, copy_info = ge.copy( + graph, dst_graph=copied_graph, dst_scope="imported") + copied_result = copy_info.transformed(result) + copied_choice = copy_info.transformed(choice) + with copied_graph.as_default(): + with session.Session() as sess: + res = sess.run(copied_result, feed_dict={copied_choice: True}) + self.assertEqual(res, 1) + res = sess.run(copied_result, feed_dict={copied_choice: False}) + self.assertEqual(res, 2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 14ac5296657d48c7f9e94d220c9e7e28af4d4353..a320a3f232fc1dc8c9ccfd1d0f2a9a40225db5cb 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -129,36 +129,51 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): return None -def copy_op_handler(info, op, copy_shape=True): +def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): """Copy a `tf.Operation`. Args: info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. + new_inputs: The new inputs for this op. copy_shape: also copy the shape of the tensor + nodedef_fn: If provided, a function that will be run on the NodeDef + and should return a mutated NodeDef before a new Operation is created. + This is useful as certain features cannot be set on the Operation and + must be modified in NodeDef. + Returns: A `(op, op_outputs)` tuple containing the transformed op and its outputs. """ + # The `new_inputs` was added to this function. For compatibility reason, + # let's raise an error if `new_inputs` is a boolean. + if isinstance(new_inputs, bool): + raise TypeError("the `new_inputs` argument must be an iterable.") + # pylint: disable=protected-access # Clone the node def: - node_def_ = deepcopy(op._node_def) + node_def_ = deepcopy(op.node_def) # Transform name: name_ = info.new_name(op.name) name_ = info.graph_.unique_name(name_) node_def_.name = name_ + # Mutate NodeDef if requested: + if nodedef_fn is not None: + node_def_ = nodedef_fn(node_def_) + # Copy the other inputs needed for initialization output_types_ = op._output_types[:] input_types_ = op._input_types[:] # Make a copy of the op_def too. # Its unique to every _type_ of Operation. - op_def_ = deepcopy(op._op_def) + op_def_ = deepcopy(op.op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, + op_ = tf_ops.Operation(node_def_, info.graph_, new_inputs, output_types_, [], input_types_, None, op_def_) # copy the shape over @@ -170,6 +185,7 @@ def copy_op_handler(info, op, copy_shape=True): # attribute to exist, we will create a dummy original_op first and then # later finalise it with the actual original_op when all the ops have # been copied. + # TODO(fkp): Stop worrying about _original_op and remove this code? if op._original_op: op_._original_op = op._original_op @@ -328,6 +344,14 @@ class _TmpInfo(object): for key in self.graph.get_all_collection_keys()) self.cyclic_ops = [] self.transform_original_op_handler = transform_op_if_inside_handler + # The graph is transformed op by op, in the same order the original ops + # were created. However, this is sometimes not possible due to cycles + # (i.e. while loops). So when the transformer creates a new op whose + # inputs do not exist yet, temporary placeholders are created and stored + # in this `tmp_cyclic_ts` container. During a second pass, + # those temporary tensors are replaced by the proper transformed tensors + # (see the function `_finalize_cycles`). + self.tmp_cyclic_ts = [] def new_name(self, name): """Compute a destination name from a source name. @@ -428,10 +452,10 @@ class Transformer(object): # Create temporary info used during this transform call info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) - info.transform_original_op_handler = self.transform_original_op_handler self._copy_ops(info) - self._connect_ops(info) + self._finalize_cycles(info) + self._connect_control_inputs(info) # Compute information about the transformation res_info = TransformerInfo(info) @@ -440,10 +464,10 @@ class Transformer(object): def _copy_ops(self, info): """Copy ops without connecting them.""" - for op in info.sgv.ops: - logging.debug("Copying op: %s", op.name) - # TODO(fkp): return a subgraph? - op_, op_outputs_ = self.transform_op_handler(info, op) + sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id) # pylint: disable=protected-access + for op in sorted_ops: + new_inputs = [self._transformed_t(info, t, op) for t in op.inputs] + op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs) if op is op_: raise ValueError("In-place transformation not allowed.") @@ -456,27 +480,36 @@ class Transformer(object): info.transformed_ts[op_output] = op_output_ self.assign_collections_handler(info, op_output, op_output_) - def _connect_ops(self, info): + def _finalize_cycles(self, info): + """Reconnects the cyclic tensors.""" + for t, tmp_t_, consumer_op in info.tmp_cyclic_ts: + if t not in info.transformed_ts: + raise ValueError("The tensor {} should be transformed by now.".format( + t.name)) + if consumer_op not in info.transformed_ops: + raise ValueError("The op {} should be transformed by now.".format( + consumer_op.name)) + t_ = info.transformed_ts[t] + consumer_op_ = info.transformed_ops[consumer_op] + t_index_ = list(consumer_op_.inputs).index(tmp_t_) + consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + + def _connect_control_inputs(self, info): """Connect the previously copied ops.""" for op in info.sgv.ops: - logging.debug("Finalizing op: %s", op.name) + logging.debug("Connecting control inputs of op: %s", op.name) op_ = info.transformed_ops[op] - # pylint: disable=protected-access - if op_.inputs: - raise ValueError("The newly transformed op should not have " - "any inputs yet: {}".format(op_.name)) - inputs_ = [self._transformed_t(info, t) for t in op.inputs] - for t in inputs_: - op_._add_input(t) - # Finalize original op. + # TODO(fkp): Stop worrying about _original_op and remove this code? + # pylint: disable=protected-access if op._original_op: - original_op = info.transform_original_op_handler(info, op._original_op) + original_op = self.transform_original_op_handler(info, op._original_op) if original_op is None: logging.debug("Could not find original op for: %s", op_.name) else: op_._original_op = original_op + # pylint: enable=protected-access # Finalize control inputs: control_inputs_ = [self.transform_control_input_handler(info, ci) @@ -525,19 +558,38 @@ class Transformer(object): return sgv_.remap(input_map_, output_map_) - def _transformed_t(self, info, t): + def _transformed_t(self, info, t, consumer_op): """Return tre transformed tensor of `t`.""" - if t not in info.transformed_ts: - # If op is not in the subgraph. - if t in info.sgv_inputs_set: - # t is an input of the subgraph. - return self.transform_external_input_handler(info, t) + if t in info.transformed_ts: + # If op is in the subgraph, just return its transformed counterpart. + return info.transformed_ts[t] + + if t in info.sgv_inputs_set: + # `t` is an input of the subgraph. + return self.transform_external_input_handler(info, t) + elif t.op in info.ops: + # `t` is an internal tensor but is not transformed yet because it + # belongs to a graph cycle. + logging.debug("Cyclic tensor: t.name = %s", t.name) + # Try to find an existing tensor we can use for now, + # otherwise create one. We'll rewire this later. + if consumer_op.type == "Merge": + first_input = consumer_op.inputs[0] + tmp_t_ = self._transformed_t(info, first_input, consumer_op) + elif t.op.type == "Enter": + enter_input = t.op.inputs[0] + tmp_t_ = self._transformed_t(info, enter_input, consumer_op) else: - # t is a hidden input of the subgraph. - return self.transform_external_hidden_input_handler(info, t) + with info.graph_.as_default(): + tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_, + prefix="geph_tmp") + logging.debug("Created temporary placeholder: %s.", tmp_t_.name) + # Register as temporary and return. + info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op)) + return tmp_t_ else: - # If op is in the subgraph, just return its transformed. - return info.transformed_ts[t] + # `t` is a hidden input of the subgraph. + return self.transform_external_hidden_input_handler(info, t) def copy(sgv, dst_graph=None, dst_scope="", src_scope="", @@ -624,6 +676,40 @@ def copy_with_input_replacements(sgv, replacement_ts, sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope) +def _add_control_flow_ops(ops, control_ios): + """Complete `ops` so that the tranformed graph is valid. + + Partially copying a graph can lead to a malformed graph. For instance, + copying half of a while construct is likely to result in an invalid graph. + This function attempts to add missing ops so that the transformation result + in a valid graph. + + Args: + ops: list of ops (modifed in-place). + control_ios: object created by a call to `util.ControlOutputs`. + """ + # Find while contexts. + control_flow_contexts = set() + for op in ops: + cfc = op._control_flow_context # pylint: disable=protected-access + if cfc: + control_flow_contexts.add(cfc) + # Find new ops. + new_ops = [] + for cfc in control_flow_contexts: + if cfc.IsWhileContext(): + new_ops += select.get_walks_intersection_ops( + [enter_t.op for enter_t in cfc.loop_enters], + [exit_t.op for exit_t in cfc.loop_exits], + control_ios=control_ios) + # Add new ops. + new_ops_set = set(new_ops) + ops_set = frozenset(ops) + for op in new_ops_set: + if op not in ops_set: + ops.append(op) + + def graph_replace(target_ts, replacement_ts, dst_scope="", src_scope="", reuse_dst_scope=False): """Create a new graph which compute the targets from the replaced Tensors. @@ -657,8 +743,13 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", control_ios=control_ios) if not ops: raise ValueError("Targets and replacements are not connected!") + + # Complete ops to avoid malformed control flow. + # TODO(fkp): Consider moving this function deeper (in the transformer?). + _add_control_flow_ops(ops, control_ios) + # Create a copy of the relevant subgraph - _, info = copy_with_input_replacements( + unused_sgv_, info = copy_with_input_replacements( ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) # Return the transformed targets but keep the original if the transformed # counterpart cannot be found diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 30bc33b9ee42ba78bc7307c67c0fc0af9f3356ef..584f4509ccc0aab30edc2be3bad7a9cb938d6e6a 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -38,6 +38,11 @@ __all__ = [ ] +# The graph editor sometimes need to create placeholders, they are named +# "geph_*". "geph" stands for Graph-Editor PlaceHolder. +_DEFAULT_PLACEHOLDER_PREFIX = "geph" + + def concatenate_unique(la, lb): """Add all the elements of `lb` to `la` if they are not there already. @@ -405,7 +410,7 @@ def scope_basename(scope): return scope[slash + 1:] -def placeholder_name(t=None, scope=None): +def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create placeholder name for the graph editor. Args: @@ -413,6 +418,7 @@ def placeholder_name(t=None, scope=None): on scope: absolute scope with which to prefix the placeholder's name. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A new placeholder name prefixed by "geph". Note that "geph" stands for Graph Editor PlaceHolder. This convention allows to quickly identify the @@ -430,19 +436,20 @@ def placeholder_name(t=None, scope=None): if scope is None: scope = op_dirname - if op_basename.startswith("geph__"): + if op_basename.startswith("{}__".format(prefix)): ph_name = op_basename else: - ph_name = "geph__{}_{}".format(op_basename, t.value_index) + ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) return scope + ph_name else: if scope is None: scope = "" - return scope + "geph" + return "{}{}".format(scope, prefix) -def make_placeholder_from_tensor(t, scope=None): +def make_placeholder_from_tensor(t, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a `tf.placeholder` for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -452,17 +459,19 @@ def make_placeholder_from_tensor(t, scope=None): (see function placeholder_name). scope: absolute scope within which to create the placeholder. None means that the scope of `t` is preserved. `""` means the root scope. + prefix: placeholder name prefix. Returns: A newly created `tf.placeholder`. Raises: TypeError: if `t` is not `None` or a `tf.Tensor`. """ return tf_array_ops.placeholder( - dtype=t.dtype, shape=t.get_shape(), name=placeholder_name( - t, scope=scope)) + dtype=t.dtype, shape=t.get_shape(), + name=placeholder_name(t, scope=scope, prefix=prefix)) -def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): +def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, + prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a tf.placeholder for the Graph Editor. Note that the correct graph scope must be set by the calling function. @@ -474,11 +483,13 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): shape: the tensor shape (optional). scope: absolute scope within which to create the placeholder. None means that the scope of t is preserved. "" means the root scope. + prefix: placeholder name prefix. Returns: A newly created tf.placeholder. """ return tf_array_ops.placeholder( - dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) + dtype=dtype, shape=shape, + name=placeholder_name(scope=scope, prefix=prefix)) _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD index d601a1ec6f7a219bcd461d819ab2dfc64135a3ae..d0b44640667010b58c017d933d50ae5f87e8b275 100644 --- a/tensorflow/contrib/grid_rnn/BUILD +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -41,15 +41,3 @@ cuda_py_tests( "//tensorflow/python:variables", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py index 252788140f8c1906718c150574b963385b6ecfa1..bcd2a34c4e791a2ab66a439109145d6b78c14e22 100644 --- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py +++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py @@ -110,7 +110,7 @@ class GridRNNCell(rnn.RNNCell): logging.warning('%s: Using a concatenated state is slower and will ' 'soon be deprecated. Use state_is_tuple=True.', self) if not output_is_tuple: - logging.warning('%s: Using a concatenated output is slower and will' + logging.warning('%s: Using a concatenated output is slower and will ' 'soon be deprecated. Use output_is_tuple=True.', self) if num_dims < 1: diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index 1b528d7afc1112f5dc0667ae299ade02bc8fd04b..d65b2d6026dd89959aa62b57e07b073eef84572c 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -23,14 +23,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/hvx/README.md b/tensorflow/contrib/hvx/README.md index 163993a3f6bb1bedcdffb32944a98c7cc846878e..68e34f3b0938f795c8ad4c8c75226f6b0afe188d 100644 --- a/tensorflow/contrib/hvx/README.md +++ b/tensorflow/contrib/hvx/README.md @@ -42,11 +42,12 @@ If you've finished walking through the quick start guide, you may want to try bu ### Build libhexagon\_nn\_skel.so -Download Hexagon NN library from codeaurora.org and build it. +Download Hexagon NN library from codeaurora.org and build it. For Hexagon SDK 3.0, we need use the compatible version([721b2d58f](https://source.codeaurora.org/quic/hexagon_nn/nnlib/commit/?id=721b2d58f0f4e2d5b182f41e6b7c4db5356bf0fb)) of nnlib. ```shell git clone https://source.codeaurora.org/quic/hexagon_nn/nnlib cd nnlib +git reset 721b2d58f --hard ``` Just follow the instructions in `README.HOW_TO_BUILD`. You can find the file `libhexagon_nn_skel.so` in `hexagon_Release_dynamic_toolv72_v60/ship`. diff --git a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD index 324035100df366b80f57af9052c4bd935655b248..e39c60b252a1b49a68d51302fff47734869dddfe 100644 --- a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD +++ b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD @@ -13,18 +13,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//visibility:public"]) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - tf_cc_binary( name = "clock_cycle_profiling", testonly = 1, diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD index 909dc396a33b6fef1b2d51c3f52fab7782fc8ea5..0081fb61770075a2c36e92f65e01126f657edeb4 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/BUILD @@ -10,17 +10,6 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - tf_cc_binary( name = "hvx_ops_support_checker", testonly = 1, diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 3ff02e085ee63fabf42b3cc4389f4605455f3800..da450480b30b548484e69c61c85667d6dd390417 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -78,7 +78,10 @@ tf_custom_op_py_library( ], srcs_version = "PY2AND3", deps = [ + ":dense_image_warp_py", ":image_ops", + ":interpolate_spline_py", + ":sparse_image_warp_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", @@ -194,6 +197,117 @@ cuda_py_test( ], ) +py_library( + name = "dense_image_warp_py", + srcs = [ + "python/ops/dense_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_library( + name = "interpolate_spline_py", + srcs = [ + "python/ops/interpolate_spline.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +py_library( + name = "sparse_image_warp_py", + srcs = [ + "python/ops/sparse_image_warp.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dense_image_warp_py", + ":interpolate_spline_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +cuda_py_test( + name = "sparse_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/sparse_image_warp_test.py"], + additional_deps = [ + ":sparse_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], + data = [":sparse_image_warp_test_data"], + tags = ["no_pip"], +) + +filegroup( + name = "sparse_image_warp_test_data", + srcs = glob(["python/kernel_tests/test_data/*.png"]), +) + +cuda_py_test( + name = "dense_image_warp_test", + size = "medium", + srcs = ["python/kernel_tests/dense_image_warp_test.py"], + additional_deps = [ + ":dense_image_warp_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + +cuda_py_test( + name = "interpolate_spline_test", + size = "medium", + srcs = ["python/kernel_tests/interpolate_spline_test.py"], + additional_deps = [ + ":interpolate_spline_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:image_ops", + "//tensorflow/python:variables", + "//tensorflow/core:protos_all_py", + ], +) + tf_py_test( name = "segmentation_test", size = "medium", @@ -270,15 +384,3 @@ cuda_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index cc8ed117ba2edcc7a53e609381166f17a2fbb45e..e982030bc8959309e72d0f4e02b9755c48535a10 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -30,6 +30,9 @@ projective transforms (including rotation) are supported. @@transform @@translate @@translations_to_projective_transforms +@@dense_image_warp +@@interpolate_spline +@@sparse_image_warp ## Image Segmentation `Ops` @@ -47,6 +50,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.python.ops.dense_image_warp import dense_image_warp + from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq @@ -57,7 +62,9 @@ from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms +from tensorflow.contrib.image.python.ops.interpolate_spline import interpolate_spline from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms +from tensorflow.contrib.image.python.ops.sparse_image_warp import sparse_image_warp from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc index b71ff9cd507faac66b3a33d3c02ec9b5901d814a..645abbf0b0ea5465dadf55d065e997e16940c18d 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -53,7 +53,7 @@ void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, OP_REQUIRES_OK(ctx, ctx->allocate_temp( DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), &tranformation_matrix)); - // TODO(huangyp): It takes about 3.5 us to comute tranformation_matrix + // TODO(huangyp): It takes about 3.5 us to compute tranformation_matrix // with one thread. Improve its performance if necessary. internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc index fe8bf6e21c7b7310527668324571774e8bc50893..93722896233f0278c6cbb44af7203345e58c3172 100644 --- a/tensorflow/contrib/image/kernels/segmentation_ops.cc +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -101,8 +101,8 @@ struct ImageConnectedComponentsFunctor { int cost = (union_find.block_height() + union_find.block_width()) * 20; Shard(worker_threads->num_threads, worker_threads->workers, num_images * num_blocks_vertically * num_blocks_horizontally, cost, - [&union_find, num_images, num_blocks_vertically, - num_blocks_horizontally](int64 start_block, int64 limit_block) { + [&union_find, num_blocks_vertically, num_blocks_horizontally]( + int64 start_block, int64 limit_block) { for (int64 i = start_block; i < limit_block; i++) { int64 block_x = i % num_blocks_horizontally; int64 block_y = diff --git a/tensorflow/contrib/image/ops/distort_image_ops.cc b/tensorflow/contrib/image/ops/distort_image_ops.cc index b169b0b2b22ad6432baed2cc96711da5ca995875..ca49635d5d0bc7bb84b19508a74be74362d96ddf 100644 --- a/tensorflow/contrib/image/ops/distort_image_ops.cc +++ b/tensorflow/contrib/image/ops/distort_image_ops.cc @@ -36,9 +36,9 @@ REGISTER_OP("AdjustHsvInYiq") Adjust the YIQ hue of one or more images. `images` is a tensor of at least 3 dimensions. The last dimension is -interpretted as channels, and must be three. +interpreted as channels, and must be three. -We used linear transfomation described in: +We used linear transformation described in: beesbuzz.biz/code/hsv_color_transforms.php The input image is considered in the RGB colorspace. Conceptually, the RGB colors are first mapped into YIQ space, rotated around the Y channel by diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index 68771b3d054a64ba94141c092e20df1ed6b2339b..ebdcaea7abae2a967786831b62b331897aa3f6a3 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -93,7 +93,7 @@ row_to_col_match_indices: A vector of length num_rows, which is the number of If `row_to_col_match_indices[i]` is not -1, row i is matched to column `row_to_col_match_indices[i]`. col_to_row_match_indices: A vector of length num_columns, which is the number - of columns of the input ditance matrix. + of columns of the input distance matrix. If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. )doc"); diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 8139d4272d6950815bd39a64e86e0f7422e6f799..bd784c6bda0344c092c1ae0af2c60be50fdff102 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -69,7 +69,7 @@ Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the -encode 3-D data witin the image. +encode 3-D data within the image. This Op is based upon: 'http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper' @@ -111,7 +111,7 @@ output_image_shape: Output size of returned image in X,Y, Channels 1-grayscale, output_data_window: Size of "DATA" window, must be equal to or smaller than 'output_image_shape', will be centered and use 'convergence_dots_size' for best fit to avoid overlap if possible -image:= A tensor of size 'output_image_shape' with the encloded 'depth_values' +image:= A tensor of size 'output_image_shape' with the encoded 'depth_values' )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a58b6a247ed6ae252db25a12f1e47c08c9a5c147 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py @@ -0,0 +1,267 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes + +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import adam + + +class DenseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def test_interpolate_small_grid_ij(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_xy(self): + grid = constant_op.constant( + [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) + query_points = constant_op.constant( + [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]) + expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) + + interp = dense_image_warp._interpolate_bilinear( + grid, query_points, indexing='xy') + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def test_interpolate_small_grid_batched(self): + grid = constant_op.constant( + [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1]) + query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]], + [[0.5, 0.], [1., 0.], [1., 1.]]]) + expected_results = np.reshape( + np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1]) + + interp = dense_image_warp._interpolate_bilinear(grid, query_points) + + with self.test_session() as sess: + predicted = sess.run(interp) + self.assertAllClose(expected_results, predicted) + + def get_image_and_flow_placeholders(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + flow_shape = [batch_size, height, width, 2] + + tf_type = { + 'float16': dtypes.half, + 'float32': dtypes.float32, + 'float64': dtypes.float64 + } + + image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape) + + flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape) + return image, flows + + def get_random_image_and_flows(self, shape, image_type, flow_type): + batch_size, height, width, numchannels = shape + image_shape = [batch_size, height, width, numchannels] + image = np.random.normal(size=image_shape) + flow_shape = [batch_size, height, width, 2] + flows = np.random.normal(size=flow_shape) * 3 + return image.astype(image_type), flows.astype(flow_type) + + def assert_correct_interpolation_value(self, + image, + flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=False): + """Assert that the tf interpolation matches hand-computed value.""" + + height = image.shape[1] + width = image.shape[2] + displacement = flows[batch_index, y_index, x_index, :] + float_y = y_index - displacement[0] + float_x = x_index - displacement[1] + floor_y = max(min(height - 2, math.floor(float_y)), 0) + floor_x = max(min(width - 2, math.floor(float_x)), 0) + ceil_y = floor_y + 1 + ceil_x = floor_x + 1 + + alpha_y = min(max(0.0, float_y - floor_y), 1.0) + alpha_x = min(max(0.0, float_x - floor_x), 1.0) + + floor_y = int(floor_y) + floor_x = int(floor_x) + ceil_y = int(ceil_y) + ceil_x = int(ceil_x) + + top_left = image[batch_index, floor_y, floor_x, :] + top_right = image[batch_index, floor_y, ceil_x, :] + bottom_left = image[batch_index, ceil_y, floor_x, :] + bottom_right = image[batch_index, ceil_y, ceil_x, :] + + interp_top = alpha_x * (top_right - top_left) + top_left + interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left + interp = alpha_y * (interp_bottom - interp_top) + interp_top + atol = 1e-6 + rtol = 1e-6 + if low_precision: + atol = 1e-2 + rtol = 1e-3 + self.assertAllClose( + interp, + pred_interpolation[batch_index, y_index, x_index, :], + atol=atol, + rtol=rtol) + + def check_zero_flow_correctness(self, shape, image_type, flow_type): + """Assert using zero flows doesn't change the input image.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + rand_flows *= 0 + + predicted_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + self.assertAllClose(rand_image, predicted_interpolation) + + def test_zero_flows(self): + """Apply check_zero_flow_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]] + for shape in shapes_to_try: + self.check_zero_flow_correctness( + shape, image_type='float32', flow_type='float32') + + def check_interpolation_correctness(self, + shape, + image_type, + flow_type, + num_probes=5): + """Interpolate, and then assert correctness for a few query locations.""" + + image, flows = self.get_image_and_flow_placeholders(shape, image_type, + flow_type) + interp = dense_image_warp.dense_image_warp(image, flows) + low_precision = image_type == 'float16' or flow_type == 'float16' + with self.test_session() as sess: + rand_image, rand_flows = self.get_random_image_and_flows( + shape, image_type, flow_type) + + pred_interpolation = sess.run( + interp, feed_dict={ + image: rand_image, + flows: rand_flows + }) + + for _ in range(num_probes): + batch_index = np.random.randint(0, shape[0]) + y_index = np.random.randint(0, shape[1]) + x_index = np.random.randint(0, shape[2]) + + self.assert_correct_interpolation_value( + rand_image, + rand_flows, + pred_interpolation, + batch_index, + y_index, + x_index, + low_precision=low_precision) + + def test_interpolation(self): + """Apply check_interpolation_correctness() for a few sizes and types.""" + + shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]] + for im_type in ['float32', 'float64', 'float16']: + for flow_type in ['float32', 'float64', 'float16']: + for shape in shapes_to_try: + self.check_interpolation_correctness(shape, im_type, flow_type) + + def test_gradients_exist(self): + """Check that backprop can run. + + The correctness of the gradients is assumed, since the forward propagation + is tested to be correct and we only use built-in tf ops. + However, we perform a simple test to make sure that backprop can actually + run. We treat the flows as a tf.Variable and optimize them to minimize + the difference between the interpolated image and the input image. + """ + + batch_size, height, width, numchannels = [4, 5, 6, 7] + image_shape = [batch_size, height, width, numchannels] + image = random_ops.random_normal(image_shape) + flow_shape = [batch_size, height, width, 2] + init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25) + flows = variables.Variable(init_flows) + + interp = dense_image_warp.dense_image_warp(image, flows) + loss = math_ops.reduce_mean(math_ops.square(interp - image)) + + optimizer = adam.AdamOptimizer(1.0) + grad = gradients.gradients(loss, [flows]) + opt_func = optimizer.apply_gradients(zip(grad, [flows])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(10): + sess.run(opt_func) + + def test_size_exception(self): + """Make sure it throws an exception for images that are too small.""" + + shape = [1, 2, 1, 1] + msg = 'Should have raised an exception for invalid image size' + with self.assertRaises(ValueError, msg=msg): + self.check_interpolation_correctness(shape, 'float32', 'float32') + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1939caaa2d8586413cf9ecba6ce73cf64910d6fc --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/interpolate_spline_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for interpolate_spline.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import interpolate as sc_interpolate + +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util + +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + +from tensorflow.python.training import momentum + + +class _InterpolationProblem(object): + """Abstract class for interpolation problem descriptions.""" + + def get_problem(self, optimizable=False, extrapolate=True, dtype='float32'): + """Make data for an interpolation problem where all x vectors are n-d. + + Args: + optimizable: If True, then make train_points a tf.Variable. + extrapolate: If False, then clamp the query_points values to be within + the max and min of train_points. + dtype: The data type to use. + + Returns: + query_points, query_values, train_points, train_values: training and + test tensors for interpolation problem + """ + + # The values generated here depend on a seed of 0. + np.random.seed(0) + + batch_size = 1 + num_training_points = 10 + num_query_points = 4 + + init_points = np.random.uniform( + size=[batch_size, num_training_points, self.DATA_DIM]) + + init_points = init_points.astype(dtype) + train_points = ( + variables.Variable(init_points) + if optimizable else constant_op.constant(init_points)) + train_values = self.tf_function(train_points) + + query_points_np = np.random.uniform( + size=[batch_size, num_query_points, self.DATA_DIM]) + query_points_np = query_points_np.astype(dtype) + if not extrapolate: + query_points_np = np.clip(query_points_np, np.min(init_points), + np.max(init_points)) + + query_points = constant_op.constant(query_points_np) + query_values = self.np_function(query_points_np) + + return query_points, query_values, train_points, train_values + + +class _QuadraticPlusSinProblem1D(_InterpolationProblem): + """1D interpolation problem used for regression testing.""" + DATA_DIM = 1 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [6.2647187603, -7.84362604077, -5.63690142322, 1.42928896387], + (1.0, + 0.01): [6.77688289946, -8.02163669853, -5.79491157027, 1.4063285693], + (2.0, + 0.0): [8.67110264937, -8.41281390883, -5.80190044693, 1.50155606059], + (2.0, + 0.01): [6.70797816797, -7.49709587663, -5.28965776238, 1.52284731741], + (3.0, + 0.0): [9.37691802935, -8.50390141515, -5.80786417426, 1.63467762122], + (3.0, + 0.01): [4.47106304758, -5.71266128361, -3.92529303296, 1.86755293857], + (4.0, + 0.0): [9.58172461111, -8.51432104771, -5.80967675388, 1.63361164256], + (4.0, 0.01): [ + -3.87902711352, -0.0253462273846, 1.79857618022, -0.769339675725 + ] + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.power((x - 0.5), 3) - 0.25 * x + 10 * np.sin(x * 10), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_mean( + math_ops.pow((x - 0.5), 3) - 0.25 * x + 10 * math_ops.sin(x * 10), + 2, + keepdims=True) + + +class _QuadraticPlusSinProblemND(_InterpolationProblem): + """3D interpolation problem used for regression testing.""" + + DATA_DIM = 3 + HARDCODED_QUERY_VALUES = { + (1.0, 0.0): [1.06609663962, 1.28894849357, 1.10882405595, 1.63966936885], + (1.0, 0.01): [1.03123780748, 1.2952930985, 1.10366822954, 1.65265118569], + (2.0, 0.0): [0.627787735064, 1.43802857251, 1.00194632358, 1.91667538215], + (2.0, 0.01): [0.730159985046, 1.41702471595, 1.0065827217, 1.85758519312], + (3.0, 0.0): [0.350460417862, 1.67223539464, 1.00475331246, 2.31580322491], + (3.0, + 0.01): [0.624557250556, 1.63138876667, 0.976588193162, 2.12511237866], + (4.0, + 0.0): [0.898129669986, 1.24434133638, -0.938056116931, 1.59910338833], + (4.0, + 0.01): [0.0930360338179, -3.38791305538, -1.00969032567, 0.745535080382], + } + + def np_function(self, x): + """Takes np array, evaluates the test function, and returns np array.""" + return np.sum( + np.square(x - 0.5) + 0.25 * x + 1 * np.sin(x * 15), + axis=2, + keepdims=True) + + def tf_function(self, x): + """Takes tf tensor, evaluates the test function, and returns tf tensor.""" + return math_ops.reduce_sum( + math_ops.square(x - 0.5) + 0.25 * x + 1 * math_ops.sin(x * 15), + 2, + keepdims=True) + + +class InterpolateSplineTest(test_util.TensorFlowTestCase): + + def test_1d_linear_interpolation(self): + """For 1d linear interpolation, we can compare directly to scipy.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, train_values) = tp.get_problem( + extrapolate=False, dtype='float64') + interpolation_order = 1 + + with ops.name_scope('interpolator'): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order) + with self.test_session() as sess: + fetches = [query_points, train_points, train_values, interpolator] + query_points_, train_points_, train_values_, interp_ = sess.run(fetches) + + # Just look at the first element of the minibatch. + # Also, trim the final singleton dimension. + interp_ = interp_[0, :, 0] + query_points_ = query_points_[0, :, 0] + train_points_ = train_points_[0, :, 0] + train_values_ = train_values_[0, :, 0] + + # Compute scipy interpolation. + scipy_interp_function = sc_interpolate.interp1d( + train_points_, train_values_, kind='linear') + + scipy_interpolation = scipy_interp_function(query_points_) + scipy_interpolation_on_train = scipy_interp_function(train_points_) + + # Even with float64 precision, the interpolants disagree with scipy a + # bit due to the fact that we add the EPSILON to prevent sqrt(0), etc. + tol = 1e-3 + + self.assertAllClose( + train_values_, scipy_interpolation_on_train, atol=tol, rtol=tol) + self.assertAllClose(interp_, scipy_interpolation, atol=tol, rtol=tol) + + def test_1d_interpolation(self): + """Regression test for interpolation with 1-D points.""" + + tp = _QuadraticPlusSinProblem1D() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_nd_linear_interpolation(self): + """Regression test for interpolation with N-D points.""" + + tp = _QuadraticPlusSinProblemND() + (query_points, _, train_points, + train_values) = tp.get_problem(dtype='float64') + + for order in (1, 2, 3): + for reg_weight in (0, 0.01): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, order, reg_weight) + + target_interpolation = tp.HARDCODED_QUERY_VALUES[(order, reg_weight)] + target_interpolation = np.array(target_interpolation) + with self.test_session() as sess: + interp_val = sess.run(interpolator) + self.assertAllClose(interp_val[0, :, 0], target_interpolation) + + def test_interpolation_gradient(self): + """Make sure that backprop can run. Correctness of gradients is assumed. + + Here, we create a use a small 'training' set and a more densely-sampled + set of query points, for which we know the true value in advance. The goal + is to choose x locations for the training data such that interpolating using + this training data yields the best reconstruction for the function + values at the query points. The training data locations are optimized + iteratively using gradient descent. + """ + tp = _QuadraticPlusSinProblemND() + (query_points, query_values, train_points, + train_values) = tp.get_problem(optimizable=True) + + regularization = 0.001 + for interpolation_order in (1, 2, 3, 4): + interpolator = interpolate_spline.interpolate_spline( + train_points, train_values, query_points, interpolation_order, + regularization) + + loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator)) + + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [train_points]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [train_points])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(100): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0135c66e293693345c3da7fdb21e28ca6d160154 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/sparse_image_warp_test.py @@ -0,0 +1,254 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sparse_image_warp.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import sparse_image_warp + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import image_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + +from tensorflow.python.training import momentum + + +class SparseImageWarpTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def testGetBoundaryLocations(self): + image_height = 11 + image_width = 11 + num_points_per_edge = 4 + locs = sparse_image_warp._get_boundary_locations(image_height, image_width, + num_points_per_edge) + num_points = locs.shape[0] + self.assertEqual(num_points, 4 + 4 * num_points_per_edge) + locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)] + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (2, 4, 6, 8): + for j in (0, image_width - 1): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + for i in (0, image_height - 1): + for j in (2, 4, 6, 8): + self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) + + def testGetGridLocations(self): + image_height = 5 + image_width = 3 + grid = sparse_image_warp._get_grid_locations(image_height, image_width) + for i in range(image_height): + for j in range(image_width): + self.assertEqual(grid[i, j, 0], i) + self.assertEqual(grid[i, j, 1], j) + + def testZeroShift(self): + """Run assertZeroShift for various hyperparameters.""" + for order in (1, 2): + for regularization in (0, 0.01): + for num_boundary_points in (0, 1): + self.assertZeroShift(order, regularization, num_boundary_points) + + def assertZeroShift(self, order, regularization, num_boundary_points): + """Check that warping with zero displacements doesn't change the image.""" + batch_size = 1 + image_height = 4 + image_width = 4 + channels = 3 + + image = np.random.uniform( + size=[batch_size, image_height, image_width, channels]) + + input_image_op = constant_op.constant(np.float32(image)) + + control_point_locations = [[1., 1.], [2., 2.], [2., 1.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + + control_point_displacements = np.zeros( + control_point_locations.shape.as_list()) + control_point_displacements = constant_op.constant( + np.float32(control_point_displacements)) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + regularization_weight=regularization, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, _ = sess.run( + [warped_image_op, input_image_op, flow_field]) + + self.assertAllClose(warped_image, input_image) + + def testMoveSinglePixel(self): + """Run assertMoveSinglePixel for various hyperparameters and data types.""" + for order in (1, 2): + for num_boundary_points in (1, 2): + for type_to_use in (dtypes.float32, dtypes.float64): + self.assertMoveSinglePixel(order, num_boundary_points, type_to_use) + + def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use): + """Move a single block in a small grid using warping.""" + batch_size = 1 + image_height = 7 + image_width = 7 + channels = 3 + + image = np.zeros([batch_size, image_height, image_width, channels]) + image[:, 3, 3, :] = 1.0 + input_image_op = constant_op.constant(image, dtype=type_to_use) + + # Place a control point at the one white pixel. + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0)), + dtype=type_to_use) + # Shift it one pixel to the right. + control_point_displacements = [[0., 1.0]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0)), + dtype=type_to_use) + + (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( + input_image_op, + control_point_locations, + control_point_locations + control_point_displacements, + interpolation_order=order, + num_boundary_points=num_boundary_points) + + with self.test_session() as sess: + warped_image, input_image, flow = sess.run( + [warped_image_op, input_image_op, flow_field]) + # Check that it moved the pixel correctly. + self.assertAllClose( + warped_image[0, 4, 5, :], + input_image[0, 4, 4, :], + atol=1e-5, + rtol=1e-5) + + # Test that there is no flow at the corners. + for i in (0, image_height - 1): + for j in (0, image_width - 1): + self.assertAllClose( + flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5) + + def load_image(self, image_file, sess): + image_op = image_ops.decode_png( + io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3] + return sess.run(image_op) + + def testSmileyFace(self): + """Check warping accuracy by comparing to hardcoded warped images.""" + + test_data_dir = test.test_src_dir_path('contrib/image/python/' + 'kernel_tests/test_data/') + input_file = test_data_dir + 'Yellow_Smiley_Face.png' + with self.test_session() as sess: + input_image = self.load_image(input_file, sess) + control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111], + [180 - 39, 111], [90, 143], [58, 134], + [180 - 58, 134]]) # pyformat: disable + control_point_displacements = np.asarray( + [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25], + [10, 10.75]]) + control_points_op = constant_op.constant( + np.expand_dims(np.float32(control_points[:, [1, 0]]), 0)) + control_point_displacements_op = constant_op.constant( + np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0)) + float_image = np.expand_dims(np.float32(input_image) / 255, 0) + input_image_op = constant_op.constant(float_image) + + for interpolation_order in (1, 2, 3): + for num_boundary_points in (0, 1, 4): + warp_op, _ = sparse_image_warp.sparse_image_warp( + input_image_op, + control_points_op, + control_points_op + control_point_displacements_op, + interpolation_order=interpolation_order, + num_boundary_points=num_boundary_points) + with self.test_session() as sess: + warped_image = sess.run(warp_op) + out_image = np.uint8(warped_image[0, :, :, :] * 255) + target_file = ( + test_data_dir + + 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format( + interpolation_order, num_boundary_points)) + + target_image = self.load_image(target_file, sess) + + # Check that the target_image and out_image difference is no + # bigger than 2 (on a scale of 0-255). Due to differences in + # floating point computation on different devices, the float + # output in warped_image may get rounded to a different int + # than that in the saved png file loaded into target_image. + self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3) + + def testThatBackpropRuns(self): + """Run optimization to ensure that gradients can be computed.""" + + batch_size = 1 + image_height = 9 + image_width = 12 + image = variables.Variable( + np.float32( + np.random.uniform(size=[batch_size, image_height, image_width, 3]))) + control_point_locations = [[3., 3.]] + control_point_locations = constant_op.constant( + np.float32(np.expand_dims(control_point_locations, 0))) + control_point_displacements = [[0.25, -0.5]] + control_point_displacements = constant_op.constant( + np.float32(np.expand_dims(control_point_displacements, 0))) + warped_image, _ = sparse_image_warp.sparse_image_warp( + image, + control_point_locations, + control_point_locations + control_point_displacements, + num_boundary_points=3) + + loss = math_ops.reduce_mean(math_ops.abs(warped_image - image)) + optimizer = momentum.MomentumOptimizer(0.001, 0.9) + grad = gradients.gradients(loss, [image]) + grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) + opt_func = optimizer.apply_gradients(zip(grad, [image])) + init_op = variables.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(5): + sess.run([loss, opt_func]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png new file mode 100644 index 0000000000000000000000000000000000000000..7e303881e213a82e412d18de9d9d86f368726f06 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..7fd9e4e6d69f3120428d1d778846d495cea1a989 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..86d225e5d2158804f88dca881f69ed3ab287d866 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..37e8ffae114625d0cc6a07ab2b8dbbb7413a3829 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-1-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..e49b5816120d43a669264915f1b6747606e080e0 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..df3cf2004312ed0ed0ebf1f0340cbfec7fd9ac46 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..e1799a87c8542d7e515b6185d7e8f6f75fe73f3e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-2-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png new file mode 100644 index 0000000000000000000000000000000000000000..2c346e0ce5487e21d41aa4e6306fd83a7b4ffdb4 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-0.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png new file mode 100644 index 0000000000000000000000000000000000000000..6f8b65451cc08a463e4305ddc4be0dbe2879fae9 Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-1.png differ diff --git a/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png new file mode 100644 index 0000000000000000000000000000000000000000..8e78146d955ae8f02230121e6314f3285e87611e Binary files /dev/null and b/tensorflow/contrib/image/python/kernel_tests/test_data/Yellow_Smiley_Face_Warp-interp-3-clamp-4.png differ diff --git a/tensorflow/contrib/image/python/ops/dense_image_warp.py b/tensorflow/contrib/image/python/ops/dense_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9b219ada492466919c615d8978e462e6c619d33 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/dense_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using per-pixel flow vectors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _interpolate_bilinear(grid, + query_points, + name='interpolate_bilinear', + indexing='ij'): + """Similar to Matlab's interp2 function. + + Finds values for query points on a grid using bilinear interpolation. + + Args: + grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. + query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. + name: a name for the operation (optional). + indexing: whether the query points are specified as row and column (ij), + or Cartesian coordinates (xy). + + Returns: + values: a 3-D `Tensor` with shape `[batch, N, channels]` + + Raises: + ValueError: if the indexing mode is invalid, or if the shape of the inputs + invalid. + """ + if indexing != 'ij' and indexing != 'xy': + raise ValueError('Indexing mode must be \'ij\' or \'xy\'') + + with ops.name_scope(name): + grid = ops.convert_to_tensor(grid) + query_points = ops.convert_to_tensor(query_points) + shape = grid.get_shape().as_list() + if len(shape) != 4: + msg = 'Grid must be 4 dimensional. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + batch_size, height, width, channels = shape + query_type = query_points.dtype + grid_type = grid.dtype + + if (len(query_points.get_shape()) != 3 or + query_points.get_shape()[2].value != 2): + msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received ' + 'size: ') + raise ValueError(msg + str(query_points.get_shape())) + + _, num_queries, _ = query_points.get_shape().as_list() + + if height < 2 or width < 2: + msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: ' + raise ValueError(msg + str(grid.get_shape())) + + alphas = [] + floors = [] + ceils = [] + + index_order = [0, 1] if indexing == 'ij' else [1, 0] + unstacked_query_points = array_ops.unstack(query_points, axis=2) + + for dim in index_order: + with ops.name_scope('dim-' + str(dim)): + queries = unstacked_query_points[dim] + + size_in_indexing_dimension = shape[dim + 1] + + # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 + # is still a valid index into the grid. + max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type) + min_floor = constant_op.constant(0.0, dtype=query_type) + floor = math_ops.minimum( + math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor) + int_floor = math_ops.cast(floor, dtypes.int32) + floors.append(int_floor) + ceil = int_floor + 1 + ceils.append(ceil) + + # alpha has the same type as the grid, as we will directly use alpha + # when taking linear combinations of pixel values from the image. + alpha = math_ops.cast(queries - floor, grid_type) + min_alpha = constant_op.constant(0.0, dtype=grid_type) + max_alpha = constant_op.constant(1.0, dtype=grid_type) + alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha) + + # Expand alpha to [b, n, 1] so we can use broadcasting + # (since the alpha values don't depend on the channel). + alpha = array_ops.expand_dims(alpha, 2) + alphas.append(alpha) + + if batch_size * height * width > np.iinfo(np.int32).max / 8: + error_msg = """The image size or batch size is sufficiently large + that the linearized addresses used by array_ops.gather + may exceed the int32 limit.""" + raise ValueError(error_msg) + + flattened_grid = array_ops.reshape(grid, + [batch_size * height * width, channels]) + batch_offsets = array_ops.reshape( + math_ops.range(batch_size) * height * width, [batch_size, 1]) + + # This wraps array_ops.gather. We reshape the image data such that the + # batch, y, and x coordinates are pulled into the first dimension. + # Then we gather. Finally, we reshape the output back. It's possible this + # code would be made simpler by using array_ops.gather_nd. + def gather(y_coords, x_coords, name): + with ops.name_scope('gather-' + name): + linear_coordinates = batch_offsets + y_coords * width + x_coords + gathered_values = array_ops.gather(flattened_grid, linear_coordinates) + return array_ops.reshape(gathered_values, + [batch_size, num_queries, channels]) + + # grab the pixel values in the 4 corners around each query point + top_left = gather(floors[0], floors[1], 'top_left') + top_right = gather(floors[0], ceils[1], 'top_right') + bottom_left = gather(ceils[0], floors[1], 'bottom_left') + bottom_right = gather(ceils[0], ceils[1], 'bottom_right') + + # now, do the actual interpolation + with ops.name_scope('interpolate'): + interp_top = alphas[1] * (top_right - top_left) + top_left + interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left + interp = alphas[0] * (interp_bottom - interp_top) + interp_top + + return interp + + +def dense_image_warp(image, flow, name='dense_image_warp'): + """Image warping using per-pixel flow vectors. + + Apply a non-linear warp to the image, where the warp is specified by a dense + flow field of offset vectors that define the correspondences of pixel values + in the output image back to locations in the source image. Specifically, the + pixel value at output[b, j, i, c] is + images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. + + The locations specified by this formula do not necessarily map to an int + index. Therefore, the pixel value is obtained by bilinear + interpolation of the 4 nearest pixels around + (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside + of the image, we use the nearest pixel values at the image boundary. + + + Args: + image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. + flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. + name: A name for the operation (optional). + + Note that image and flow can be of type tf.half, tf.float32, or tf.float64, + and do not necessarily have to be the same type. + + Returns: + A 4-D float `Tensor` with shape`[batch, height, width, channels]` + and same type as input image. + + Raises: + ValueError: if height < 2 or width < 2 or the inputs have the wrong number + of dimensions. + """ + with ops.name_scope(name): + batch_size, height, width, channels = image.get_shape().as_list() + # The flow is defined on the image grid. Turn the flow into a list of query + # points in the grid space. + grid_x, grid_y = array_ops.meshgrid( + math_ops.range(width), math_ops.range(height)) + stacked_grid = math_ops.cast( + array_ops.stack([grid_y, grid_x], axis=2), flow.dtype) + batched_grid = array_ops.expand_dims(stacked_grid, axis=0) + query_points_on_grid = batched_grid - flow + query_points_flattened = array_ops.reshape(query_points_on_grid, + [batch_size, height * width, 2]) + # Compute values at the query points, then reshape the result back to the + # image grid. + interpolated = _interpolate_bilinear(image, query_points_flattened) + interpolated = array_ops.reshape(interpolated, + [batch_size, height, width, channels]) + return interpolated diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index c139ae89d8d682d6b87813c3a21703ffa762f28e..cd984c80543886be1f682933e2e003bd3374e425 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -433,7 +433,7 @@ def bipartite_match(distance_mat, of rows of the input `distance_matrix`. If `row_to_col_match_indices[i]` is not -1, row i is matched to column `row_to_col_match_indices[i]`. col_to_row_match_indices: A vector of length num_columns, which is the - number of columns of the input ditance matrix. + number of columns of the input distance matrix. If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. """ diff --git a/tensorflow/contrib/image/python/ops/interpolate_spline.py b/tensorflow/contrib/image/python/ops/interpolate_spline.py new file mode 100644 index 0000000000000000000000000000000000000000..daf8c56456327f102f1409296a91f9f7b68ec799 --- /dev/null +++ b/tensorflow/contrib/image/python/ops/interpolate_spline.py @@ -0,0 +1,291 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Polyharmonic spline interpolation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + +EPSILON = 0.0000000001 + + +def _cross_squared_distance_matrix(x, y): + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). + + Computes the pairwise distances between rows of x and rows of y + Args: + x: [batch_size, n, d] float `Tensor` + y: [batch_size, m, d] float `Tensor` + + Returns: + squared_dists: [batch_size, n, m] float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 + """ + x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2) + y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2) + + # Expand so that we can broadcast. + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1) + + x_y_transpose = math_ops.matmul(x, y, adjoint_b=True) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile + + return squared_dists + + +def _pairwise_squared_distance_matrix(x): + """Pairwise squared distance among a (batch) matrix's rows (2nd dim). + + This saves a bit of computation vs. using _cross_squared_distance_matrix(x,x) + + Args: + x: `[batch_size, n, d]` float `Tensor` + + Returns: + squared_dists: `[batch_size, n, n]` float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2 + """ + + x_x_transpose = math_ops.matmul(x, x, adjoint_b=True) + x_norm_squared = array_ops.matrix_diag_part(x_x_transpose) + x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2) + + # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose( + x_norm_squared_tile, [0, 2, 1]) + + return squared_dists + + +def _solve_interpolation(train_points, train_values, order, + regularization_weight): + """Solve for interpolation coefficients. + + Computes the coefficients of the polyharmonic interpolant for the 'training' + data defined by (train_points, train_values) using the kernel phi. + + Args: + train_points: `[b, n, d]` interpolation centers + train_values: `[b, n, k]` function values + order: order of the interpolation + regularization_weight: weight to place on smoothness regularization term + + Returns: + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + """ + + b, n, d = train_points.get_shape().as_list() + _, _, k = train_values.get_shape().as_list() + + # First, rename variables so that the notation (c, f, w, v, A, B, etc.) + # follows https://en.wikipedia.org/wiki/Polyharmonic_spline. + # To account for python style guidelines we use + # matrix_a for A and matrix_b for B. + + c = train_points + f = train_values + + # Next, construct the linear system. + with ops.name_scope('construct_linear_system'): + + matrix_a = _phi(_pairwise_squared_distance_matrix(c), order) # [b, n, n] + if regularization_weight > 0: + batch_identity_matrix = np.expand_dims(np.eye(n), 0) + batch_identity_matrix = constant_op.constant( + batch_identity_matrix, dtype=train_points.dtype) + + matrix_a += regularization_weight * batch_identity_matrix + + # Append ones to the feature values for the bias term in the linear model. + ones = array_ops.ones([b, n, 1], train_points.dtype) + matrix_b = array_ops.concat([c, ones], 2) # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = array_ops.concat( + [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1) + + num_b_cols = matrix_b.get_shape()[2] # d + 1 + lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype) + right_block = array_ops.concat([matrix_b, lhs_zeros], + 1) # [b, n + d + 1, d + 1] + lhs = array_ops.concat([left_block, right_block], + 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype) + rhs = array_ops.concat([f, rhs_zeros], 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + with ops.name_scope('solve_linear_system'): + w_v = linalg_ops.matrix_solve(lhs, rhs) + w = w_v[:, :n, :] + v = w_v[:, n:, :] + + return w, v + + +def _apply_interpolation(query_points, train_points, w, v, order): + """Apply polyharmonic interpolation model to data. + + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at + train_points: `[b, n, d]` x values that act as the interpolation centers + ( the c variables in the wikipedia article) + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + order: order of the interpolation + + Returns: + Polyharmonic interpolation evaluated at points defined in query_points. + """ + + batch_size = train_points.get_shape()[0].value + num_query_points = query_points.get_shape()[1].value + + # First, compute the contribution from the rbf term. + pairwise_dists = _cross_squared_distance_matrix(query_points, train_points) + phi_pairwise_dists = _phi(pairwise_dists, order) + + rbf_term = math_ops.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + query_points_pad = array_ops.concat([ + query_points, + array_ops.ones([batch_size, num_query_points, 1], train_points.dtype) + ], 2) + linear_term = math_ops.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def _phi(r, order): + """Coordinate-wise nonlinearity used to define the order of the interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + + Args: + r: input op + order: interpolation order + + Returns: + phi_k evaluated coordinate-wise on r, for k = r + """ + + # using EPSILON prevents log(0), sqrt0), etc. + # sqrt(0) is well-defined, but its gradient is not + with ops.name_scope('phi'): + if order == 1: + r = math_ops.maximum(r, EPSILON) + r = math_ops.sqrt(r) + return r + elif order == 2: + return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON)) + elif order == 4: + return 0.5 * math_ops.square(r) * math_ops.log( + math_ops.maximum(r, EPSILON)) + elif order % 2 == 0: + r = math_ops.maximum(r, EPSILON) + return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r) + else: + r = math_ops.maximum(r, EPSILON) + return math_ops.pow(r, 0.5 * order) + + +def interpolate_spline(train_points, + train_values, + query_points, + order, + regularization_weight=0.0, + name='interpolate_spline'): + r"""Interpolate signal using polyharmonic interpolation. + + The interpolant has the form + $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$ + + This is a sum of two terms: (1) a weighted sum of radial basis function (RBF) + terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias. + The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v + by appending 1 as a final dimension to x. The coefficients w and v are + estimated such that the interpolant exactly fits the value of the function at + the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the + vector w sums to 0. With these constraints, the coefficients can be obtained + by solving a linear system. + + \\(\phi\\) is an RBF, parametrized by an interpolation + order. Using order=2 produces the well-known thin-plate spline. + + We also provide the option to perform regularized interpolation. Here, the + interpolant is selected to trade off between the squared loss on the training + data and a certain measure of its curvature + ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)). + Using a regularization weight greater than zero has the effect that the + interpolant will no longer exactly fit the training data. However, it may be + less vulnerable to overfitting, particularly for high-order interpolation. + + Note the interpolation procedure is differentiable with respect to all inputs + besides the order parameter. + + Args: + train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional + locations. These do not need to be regularly-spaced. + train_values: `[batch_size, n, k]` float `Tensor` of n c-dimensional values + evaluated at train_points. + query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations + where we will output the interpolant's values. + order: order of the interpolation. Common values are 1 for + \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline), + or 3 for \\(\phi(r) = r^3\\). + regularization_weight: weight placed on the regularization term. + This will depend substantially on the problem, and it should always be + tuned. For many problems, it is reasonable to use no regularization. + If using a non-zero value, we recommend a small value like 0.001. + name: name prefix for ops created by this function + + Returns: + `[b, m, k]` float `Tensor` of query values. We use train_points and + train_values to perform polyharmonic interpolation. The query values are + the values of the interpolant evaluated at the locations specified in + query_points. + """ + with ops.name_scope(name): + train_points = ops.convert_to_tensor(train_points) + train_values = ops.convert_to_tensor(train_values) + query_points = ops.convert_to_tensor(query_points) + + # First, fit the spline to the observed data. + with ops.name_scope('solve'): + w, v = _solve_interpolation(train_points, train_values, order, + regularization_weight) + + # Then, evaluate the spline at the query locations. + with ops.name_scope('predict'): + query_values = _apply_interpolation(query_points, train_points, w, v, + order) + + return query_values diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index d4a6a5bcbb52511d4093587814100b2a0e8b2420..0ceb683ff4c6965a5ee4bcb04846a69d4d8ea0a5 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -45,7 +45,7 @@ def single_image_random_dot_stereograms(depth_values, Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may - corrupt the encode 3-D data witin the image. + corrupt the encode 3-D data within the image. Based upon [this paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper). diff --git a/tensorflow/contrib/image/python/ops/sparse_image_warp.py b/tensorflow/contrib/image/python/ops/sparse_image_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..54a215d6db6ded56a1a4a018a7e176f35fe6397e --- /dev/null +++ b/tensorflow/contrib/image/python/ops/sparse_image_warp.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image warping using sparse flow defined at control points.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.image.python.ops import dense_image_warp +from tensorflow.contrib.image.python.ops import interpolate_spline + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +def _get_grid_locations(image_height, image_width): + """Wrapper for np.meshgrid.""" + + y_range = np.linspace(0, image_height - 1, image_height) + x_range = np.linspace(0, image_width - 1, image_width) + y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') + return np.stack((y_grid, x_grid), -1) + + +def _expand_to_minibatch(np_array, batch_size): + """Tile arbitrarily-sized np_array to include new batch dimension.""" + tiles = [batch_size] + [1] * np_array.ndim + return np.tile(np.expand_dims(np_array, 0), tiles) + + +def _get_boundary_locations(image_height, image_width, num_points_per_edge): + """Compute evenly-spaced indices along edge of image.""" + y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) + x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) + ys, xs = np.meshgrid(y_range, x_range, indexing='ij') + is_boundary = np.logical_or( + np.logical_or(xs == 0, xs == image_width - 1), + np.logical_or(ys == 0, ys == image_height - 1)) + return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) + + +def _add_zero_flow_controls_at_boundary(control_point_locations, + control_point_flows, image_height, + image_width, boundary_points_per_edge): + """Add control points for zero-flow boundary conditions. + + Augment the set of control points with extra points on the + boundary of the image that have zero flow. + + Args: + control_point_locations: input control points + control_point_flows: their flows + image_height: image height + image_width: image width + boundary_points_per_edge: number of points to add in the middle of each + edge (not including the corners). + The total number of points added is + 4 + 4*(boundary_points_per_edge). + + Returns: + merged_control_point_locations: augmented set of control point locations + merged_control_point_flows: augmented set of control point flows + """ + + batch_size = control_point_locations.get_shape()[0].value + + boundary_point_locations = _get_boundary_locations(image_height, image_width, + boundary_points_per_edge) + + boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) + + type_to_use = control_point_locations.dtype + boundary_point_locations = constant_op.constant( + _expand_to_minibatch(boundary_point_locations, batch_size), + dtype=type_to_use) + + boundary_point_flows = constant_op.constant( + _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use) + + merged_control_point_locations = array_ops.concat( + [control_point_locations, boundary_point_locations], 1) + + merged_control_point_flows = array_ops.concat( + [control_point_flows, boundary_point_flows], 1) + + return merged_control_point_locations, merged_control_point_flows + + +def sparse_image_warp(image, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundary_points=0, + name='sparse_image_warp'): + """Image warping using correspondences between sparse control points. + + Apply a non-linear warp to the image, where the warp is specified by + the source and destination locations of a (potentially small) number of + control points. First, we use a polyharmonic spline + (@{tf.contrib.image.interpolate_spline}) to interpolate the displacements + between the corresponding control points to a dense flow field. + Then, we warp the image using this dense flow field + (@{tf.contrib.image.dense_image_warp}). + + Let t index our control points. For regularization_weight=0, we have: + warped_image[b, dest_control_point_locations[b, t, 0], + dest_control_point_locations[b, t, 1], :] = + image[b, source_control_point_locations[b, t, 0], + source_control_point_locations[b, t, 1], :]. + + For regularization_weight > 0, this condition is met approximately, since + regularized interpolation trades off smoothness of the interpolant vs. + reconstruction of the interpolant at the control points. + See @{tf.contrib.image.interpolate_spline} for further documentation of the + interpolation_order and regularization_weight arguments. + + + Args: + image: `[batch, height, width, channels]` float `Tensor` + source_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + dest_control_point_locations: `[batch, num_control_points, 2]` float + `Tensor` + interpolation_order: polynomial order used by the spline interpolation + regularization_weight: weight on smoothness regularizer in interpolation + num_boundary_points: How many zero-flow boundary points to include at + each image edge.Usage: + num_boundary_points=0: don't add zero-flow points + num_boundary_points=1: 4 corners of the image + num_boundary_points=2: 4 corners and one in the middle of each edge + (8 points total) + num_boundary_points=n: 4 corners and n-1 along each edge + name: A name for the operation (optional). + + Note that image and offsets can be of type tf.half, tf.float32, or + tf.float64, and do not necessarily have to be the same type. + + Returns: + warped_image: `[batch, height, width, channels]` float `Tensor` with same + type as input image. + flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense + flow field produced by the interpolation. + """ + + image = ops.convert_to_tensor(image) + source_control_point_locations = ops.convert_to_tensor( + source_control_point_locations) + dest_control_point_locations = ops.convert_to_tensor( + dest_control_point_locations) + + control_point_flows = ( + dest_control_point_locations - source_control_point_locations) + + clamp_boundaries = num_boundary_points > 0 + boundary_points_per_edge = num_boundary_points - 1 + + with ops.name_scope(name): + + batch_size, image_height, image_width, _ = image.get_shape().as_list() + + # This generates the dense locations where the interpolant + # will be evaluated. + grid_locations = _get_grid_locations(image_height, image_width) + + flattened_grid_locations = np.reshape(grid_locations, + [image_height * image_width, 2]) + + flattened_grid_locations = constant_op.constant( + _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) + + if clamp_boundaries: + (dest_control_point_locations, + control_point_flows) = _add_zero_flow_controls_at_boundary( + dest_control_point_locations, control_point_flows, image_height, + image_width, boundary_points_per_edge) + + flattened_flows = interpolate_spline.interpolate_spline( + dest_control_point_locations, control_point_flows, + flattened_grid_locations, interpolation_order, regularization_weight) + + dense_flows = array_ops.reshape(flattened_flows, + [batch_size, image_height, image_width, 2]) + + warped_image = dense_image_warp.dense_image_warp(image, dense_flows) + + return warped_image, dense_flows diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 9d6b4d5d87e24d72b29ab33ee805fe0d068cc30a..0e34315db45d61282af1882631dc769a72965c3e 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -114,14 +114,3 @@ tf_cc_tests( "//tensorflow/core:testlib", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/input_pipeline/kernels/BUILD b/tensorflow/contrib/input_pipeline/kernels/BUILD index f20a6e38d4e80f869e9274d6fc49338a95fc6788..797605b8fe66e8375edcc70668a07a8d2a6d73f3 100644 --- a/tensorflow/contrib/input_pipeline/kernels/BUILD +++ b/tensorflow/contrib/input_pipeline/kernels/BUILD @@ -17,14 +17,3 @@ cc_library( ], alwayslink = 1, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/integrate/BUILD b/tensorflow/contrib/integrate/BUILD index 66948c1ea1f3f239d3f43a57626f8c229fe24ad9..0b7d64f4edd7587000ca5b9ecae257fe8fedd4a1 100644 --- a/tensorflow/contrib/integrate/BUILD +++ b/tensorflow/contrib/integrate/BUILD @@ -42,14 +42,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py index 68bf511099ab473d158108df6ff07827819297d9..694f0c14bd4e74535c70fab76c5f7ac58f452559 100644 --- a/tensorflow/contrib/integrate/__init__.py +++ b/tensorflow/contrib/integrate/__init__.py @@ -18,6 +18,7 @@ See the @{$python/contrib.integrate} guide. @@odeint +@@odeint_fixed """ from __future__ import absolute_import diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD index efb403462a6e5df5b69ac0735ffc03f40d4a252c..3913c9dc7abfba2829bde5e86fe2927e8fc29a9d 100644 --- a/tensorflow/contrib/kafka/BUILD +++ b/tensorflow/contrib/kafka/BUILD @@ -1,66 +1,93 @@ -package( - default_visibility = ["//visibility:private"], -) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") -load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_py_test", +) -tf_kernel_library( - name = "kafka_kernels", +py_library( + name = "kafka", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + ], +) + +tf_custom_op_library( + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = [":dataset_kernels"], +) + +tf_gen_op_libs( + op_lib_names = ["dataset_ops"], +) + +cc_library( + name = "dataset_kernels", srcs = ["kernels/kafka_dataset_ops.cc"], - visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/kernels:bounds_check_lib", - "//tensorflow/core/kernels:dataset", + "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@kafka", + "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) -tf_gen_op_libs( - op_lib_names = ["kafka_ops"], +py_library( + name = "dataset_ops", + srcs = [ + "python/ops/kafka_dataset_ops.py", + ], + srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:lib", + ":kafka_op_loader", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", ], ) tf_gen_op_wrapper_py( - name = "gen_kafka_ops", - out = "python/ops/gen_kafka_ops.py", - require_shape_functions = True, - deps = [":kafka_ops_op_lib"], + name = "gen_dataset_ops", + out = "python/ops/gen_dataset_ops.py", + deps = ["//tensorflow/contrib/kafka:dataset_ops_op_lib"], ) -py_library( - name = "kafka", - srcs = [ - "__init__.py", - "python/ops/kafka_dataset_ops.py", +tf_kernel_library( + name = "dataset_ops_kernels", + deps = [ + ":dataset_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "kafka_op_loader", + srcs = ["python/ops/kafka_op_loader.py"], + dso = ["//tensorflow/contrib/kafka:_dataset_ops.so"], + kernels = [ + ":dataset_ops_kernels", + "//tensorflow/contrib/kafka:dataset_ops_op_lib", ], srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ - ":gen_kafka_ops", + ":gen_dataset_ops", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//tensorflow/python/data/ops:readers", ], ) @@ -88,18 +115,7 @@ tf_py_test( ], tags = [ "manual", + "no_windows", "notap", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index 88ef5f357113372b0a2d0cb13382ac980a61252d..a4cd4a2cc4b99b5906185bd2b942ed15c1ddf5e4 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/dataset.h" - -#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/dataset.h" #include "src-cpp/rdkafkacpp.h" diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/dataset_ops.cc similarity index 100% rename from tensorflow/contrib/kafka/ops/kafka_ops.cc rename to tensorflow/contrib/kafka/ops/dataset_ops.cc diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py index 8e51d27a342359881de072c3979a2b5a7fc034ea..a1624614d1ab1be31463c5cdc0b4cfb653165a0c 100644 --- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -17,8 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.kafka.python.ops import gen_kafka_ops -from tensorflow.python.data.ops.readers import Dataset +from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import +from tensorflow.contrib.kafka.python.ops import gen_dataset_ops +from tensorflow.python.data.ops.dataset_ops import Dataset from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -58,8 +59,8 @@ class KafkaDataset(Dataset): timeout, dtype=dtypes.int64, name="timeout") def _as_variant_tensor(self): - return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, - self._eof, self._timeout) + return gen_dataset_ops.kafka_dataset(self._topics, self._servers, + self._group, self._eof, self._timeout) @property def output_classes(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py similarity index 75% rename from tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py rename to tensorflow/contrib/kafka/python/ops/kafka_op_loader.py index 690a44ff4368663306733300a1ea70397fb93e1e..ec2fdea962ef946d3f8f32b9e630b92649d612fe 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py +++ b/tensorflow/contrib/kafka/python/ops/kafka_op_loader.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experimental methods for tf.feature_column sequential input.""" - +"""Python helper for loading kafka ops and kernels.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 7e0019ce4ad6c96e09ac9e222e2f4e2840273983..7a4cab20d1a3471af2a2a402a6d1443a90fa7f9b 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -52,15 +52,3 @@ py_library( "//tensorflow/python/keras", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index eff7dfeb4c1117e40f4faf43c5e92a52cffd6528..87c2dcd89b63fa9f92d93c87abce91fd3460d44e 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -90,15 +90,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index f182fef067b7f523bc5ca63227265be40528b171..4ef0a66a52429233c6e6f70667a451466493629c 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -43,10 +43,10 @@ def sparse_multiclass_hinge_loss( This is a generalization of standard (binary) hinge loss. For a given instance with correct label c*, the loss is given by: - loss = max_{c != c*} logits_c - logits_{c*} + 1. + $$loss = max_{c != c*} logits_c - logits_{c*} + 1.$$ or equivalently - loss = max_c { logits_c - logits_{c*} + I_{c != c*} } - where I_{c != c*} = 1 if c != c* and 0 otherwise. + $$loss = max_c { logits_c - logits_{c*} + I_{c != c*} }$$ + where \\(I_{c != c*} = 1\ \text{if}\ c != c*\\) and 0 otherwise. Args: labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py index 9dc01124ab195ae17b8795a11e4ebefe3f2c746b..9a721a9d440e66eb30bb94daf2b6878318f1e75f 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py @@ -34,33 +34,31 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper): r"""Class that implements Random Fourier Feature Mapping (RFFM) in TensorFlow. The RFFM mapping is used to approximate the Gaussian (RBF) kernel: - ``` - exp(-||x-y||_2^2 / (2 * sigma^2)) - ``` + $$(exp(-||x-y||_2^2 / (2 * \sigma^2))$$ The implementation of RFFM is based on the following paper: "Random Features for Large-Scale Kernel Machines" by Ali Rahimi and Ben Recht. (link: https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) - The mapping uses a matrix `Omega \in R^{d x D}` and a bias vector `b \in R^D` - where `d` is the input dimension (number of dense input features) and `D` is - the output dimension (i.e., dimension of the feature space the input is mapped - to). Each entry of `Omega` is sampled i.i.d. from a (scaled) Gaussian - distribution and each entry of `b` is sampled independently and uniformly from - [0, 2 * pi]. - - For a single input feature vector x in R^d, its RFFM is defined as: - ``` - sqrt(2/D) * cos(x * Omega + b) - ``` - where `cos` is the element-wise cosine function and `x, b` are represented as - row vectors. The aforementioned paper shows that the linear kernel of - RFFM-mapped vectors approximates the Gaussian kernel of the initial vectors. + The mapping uses a matrix \\(\Omega \in R^{d x D}\\) and a bias vector + \\(b \in R^D\\) where \\(d\\) is the input dimension (number of dense input + features) and \\(D\\) is the output dimension (i.e., dimension of the feature + space the input is mapped to). Each entry of \\(\Omega\\) is sampled i.i.d. + from a (scaled) Gaussian distribution and each entry of \\(b\\) is sampled + independently and uniformly from [0, \\(2 * \pi\\)]. + + For a single input feature vector \\(x \in R^d\\), its RFFM is defined as: + $$\sqrt(2/D) * cos(x * \Omega + b)$$ + + where \\(cos\\) is the element-wise cosine function and \\(x, b\\) are + represented as row vectors. The aforementioned paper shows that the linear + kernel of RFFM-mapped vectors approximates the Gaussian kernel of the initial + vectors. """ def __init__(self, input_dim, output_dim, stddev=1.0, seed=1, name=None): - """Constructs a RandomFourierFeatureMapper instance. + r"""Constructs a RandomFourierFeatureMapper instance. Args: input_dim: The dimension (number of features) of the tensors to be mapped. @@ -68,11 +66,11 @@ class RandomFourierFeatureMapper(dkm.DenseKernelMapper): stddev: The standard deviation of the Gaussian kernel to be approximated. The error of the classifier trained using this approximation is very sensitive to this parameter. - seed: An integer used to initialize the parameters (`Omega` and `b`) of - the mapper. For repeatable sequences across different invocations of the - mapper object (for instance, to ensure consistent mapping both at - training and eval/inference if these happen in different invocations), - set this to the same integer. + seed: An integer used to initialize the parameters (\\(\Omega\\) and + \\(b\\)) of the mapper. For repeatable sequences across different + invocations of the mapper object (for instance, to ensure consistent + mapping both at training and eval/inference if these happen in + different invocations), set this to the same integer. name: name for the mapper object. """ # TODO(sibyl-vie3Poto): Maybe infer input_dim and/or output_dim (if not explicitly diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 6f4a264485993ab737723171409042b4a9673669..91929184a2e6f3cccae92cb819501a7c6ef81673 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -34,7 +34,7 @@ def _inner_product(x, y): """Inner product between tensors x and y. The input tensors are assumed to be in ROW representation, that is, the method - returns x * y^T. + returns \\(x * y^T\\). Args: x: input tensor in row format diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD index 9a5759bf14f753bbc50d3ef8f54ceab7daf745ab..b719046b37ac761d56e8d5aa34772103be691cd6 100644 --- a/tensorflow/contrib/kfac/BUILD +++ b/tensorflow/contrib/kfac/BUILD @@ -24,15 +24,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD index 89965eda374b2b403f680fc77eb923d0e660d1e2..8186fa1c62cb952f86614a96c3965bcddae1686e 100644 --- a/tensorflow/contrib/kfac/examples/BUILD +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -28,8 +28,28 @@ py_library( ) py_binary( - name = "convnet_mnist_main", - srcs = ["convnet_mnist_main.py"], + name = "convnet_mnist_single_main", + srcs = ["convnet_mnist_single_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_multi_tower_main", + srcs = ["convnet_mnist_multi_tower_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_distributed_main", + srcs = ["convnet_mnist_distributed_main.py"], srcs_version = "PY2AND3", deps = [ ":convnet", @@ -58,15 +78,3 @@ py_library( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 39d80addaac1fe855a37255b32bf4412b99df46a..e8e3353091df25e135b1247bf976bb9ce177d1a7 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -37,6 +37,8 @@ import tensorflow as tf from tensorflow.contrib.kfac.examples import mlp from tensorflow.contrib.kfac.examples import mnist +from tensorflow.contrib.kfac.python.ops import optimizer as opt + lc = tf.contrib.kfac.layer_collection oq = tf.contrib.kfac.op_queue @@ -48,12 +50,18 @@ __all__ = [ "linear_layer", "build_model", "minimize_loss_single_machine", - "minimize_loss_distributed", + "distributed_grads_only_and_ops_chief_worker", + "distributed_grads_and_ops_dedicated_workers", "train_mnist_single_machine", - "train_mnist_distributed", + "train_mnist_distributed_sync_replicas", + "train_mnist_multitower" ] +# Inverse update ops will be run every _INVERT_EVRY iterations. +_INVERT_EVERY = 10 + + def conv_layer(layer_id, inputs, kernel_size, out_channels): """Builds a convolutional layer with ReLU non-linearity. @@ -161,8 +169,9 @@ def build_model(examples, labels, num_labels, layer_collection): accuracy = tf.reduce_mean( tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - tf.summary.scalar("loss", loss) - tf.summary.scalar("accuracy", accuracy) + with tf.device("/cpu:0"): + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) # Register parameters. K-FAC needs to know about the inputs, outputs, and # parameters of each conv/fully connected layer and the logits powering the @@ -181,41 +190,59 @@ def build_model(examples, labels, num_labels, layer_collection): def minimize_loss_single_machine(loss, accuracy, layer_collection, + device="/gpu:0", session_config=None): """Minimize loss with K-FAC on a single machine. - A single Session is responsible for running all of K-FAC's ops. + A single Session is responsible for running all of K-FAC's ops. The covariance + and inverse update ops are placed on `device`. All model variables are on CPU. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and invserse + update ops are run on this device. session_config: None or tf.ConfigProto. Configuration for tf.Session(). Returns: final value for 'accuracy'. """ # Train with K-FAC. - global_step = tf.train.get_or_create_global_step() + g_step = tf.train.get_or_create_global_step() optimizer = opt.KfacOptimizer( learning_rate=0.0001, cov_ema_decay=0.95, damping=0.001, layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=[device], + inv_devices=[device], momentum=0.9) - train_op = optimizer.minimize(loss, global_step=global_step) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) - if global_step_ % 100 == 0: + if (global_step_ + 1) % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -250,16 +277,62 @@ def _num_gradient_tasks(num_tasks): return int(np.ceil(0.6 * num_tasks)) -def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection): - """Minimize loss with an synchronous implementation of K-FAC. +def _make_distributed_train_op( + task_id, + num_worker_tasks, + num_ps_tasks, + layer_collection +): + """Creates optimizer and distributed training op. - Different tasks are responsible for different parts of K-FAC's Ops. The first - 60% of tasks update weights; the next 20% accumulate covariance statistics; - the last 20% invert the matrices used to precondition gradients. + Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes + the train op. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC + optimizer. + optimizer: Instance of `opt.KfacOptimizer`. + global_step: `tensor`, Global step. + """ + tf.logging.info("Task id : %d", task_id) + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), + total_num_replicas=num_worker_tasks) + return sync_optimizer, optimizer, global_step + + +def distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection, invert_every=10): + """Minimize loss with a synchronous implementation of K-FAC. + + All workers perform gradient computation. Chief worker applies gradient after + averaging the gradients obtained from all the workers. All workers block + execution untill the update is applied. Chief worker runs covariance and + inverse update ops. Covariance and inverse matrices are placed on parameter + servers in a round robin manner. For further details on synchronous + distributed optimization check `tf.train.SyncReplicasOptimizer`. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. If 0, parameter servers are not used. @@ -271,6 +344,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, run with each step. layer_collection: LayerCollection instance describing model architecture. Used by K-FAC to construct preconditioner. + invert_every: `int`, Number of steps between update the inverse. Returns: final value for 'accuracy'. @@ -278,19 +352,80 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, Raises: ValueError: if task_id >= num_worker_tasks. """ - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) - sync_optimizer = tf.train.SyncReplicasOptimizer( - opt=optimizer, - replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) - train_op = sync_optimizer.minimize(loss, global_step=global_step) + + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + if is_chief: + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + update_op = tf.cond( + tf.equal(tf.mod(global_step + 1, invert_every), 0), + lambda: make_update_op(inv_update_thunks), + tf.no_op) + else: + update_op = train_op + + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, update_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + return accuracy_ + + +def distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, + loss, accuracy, layer_collection): + """Minimize loss with a synchronous implementation of K-FAC. + + Different workers are responsible for different parts of K-FAC's Ops. The + first 60% of tasks compute gradients; the next 20% accumulate covariance + statistics; the last 20% invert the matrices used to precondition gradients. + The chief worker applies the gradient . + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + sync_optimizer, optimizer, global_step = _make_distributed_train_op( + task_id, num_worker_tasks, num_ps_tasks, layer_collection) + _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() + train_op = sync_optimizer.minimize(loss, global_step=global_step) + inv_update_queue = oq.OpQueue(inv_update_ops) tf.logging.info("Starting training.") is_chief = (task_id == 0) @@ -306,7 +441,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, if _is_gradient_task(task_id, num_worker_tasks): learning_op = train_op elif _is_cov_update_task(task_id, num_worker_tasks): - learning_op = optimizer.cov_update_op + learning_op = cov_update_op elif _is_inv_update_task(task_id, num_worker_tasks): # TODO(duckworthd): Running this op before cov_update_op has been run a # few times can result in "InvalidArgumentError: Cholesky decomposition @@ -324,13 +459,18 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, return accuracy_ -def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): +def train_mnist_single_machine(data_dir, + num_epochs, + use_fake_data=False, + device="/gpu:0"): """Train a ConvNet on MNIST. Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. use_fake_data: bool. If True, generate a synthetic dataset. + device: string, Either '/cpu:0' or '/gpu:0'. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. @@ -350,22 +490,38 @@ def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): examples, labels, num_labels=10, layer_collection=layer_collection) # Fit model. - return minimize_loss_single_machine(loss, accuracy, layer_collection) + return minimize_loss_single_machine( + loss, accuracy, layer_collection, device=device) def train_mnist_multitower(data_dir, num_epochs, num_towers, - use_fake_data=True): + use_fake_data=True, devices=None): """Train a ConvNet on MNIST. + Training data is split equally among the towers. Each tower computes loss on + its own batch of data and the loss is aggregated on the CPU. The model + variables are placed on first tower. The covariance and inverse update ops + and variables are placed on GPUs in a round robin manner. + Args: data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. num_towers: int. Number of CPUs to split inference across. use_fake_data: bool. If True, generate a synthetic dataset. + devices: string, Either list of CPU or GPU. The covaraince and inverse + update ops are run on this device. Returns: accuracy of model on the final minibatch of training data. """ + if devices: + device_count = {"GPU": num_towers} + else: + device_count = {"CPU": num_towers} + + devices = devices or [ + "/cpu:{}".format(tower_id) for tower_id in range(num_towers) + ] # Load a dataset. tf.logging.info("Loading MNIST into memory.") tower_batch_size = 128 @@ -388,7 +544,7 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, layer_collection = lc.LayerCollection() tower_results = [] for tower_id in range(num_towers): - with tf.device("/cpu:%d" % tower_id): + with tf.device(devices[tower_id]): with tf.name_scope("tower%d" % tower_id): with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): tf.logging.info("Building tower %d." % tower_id) @@ -402,34 +558,79 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, accuracy = tf.reduce_mean(accuracies) # Fit model. + session_config = tf.ConfigProto( - allow_soft_placement=False, device_count={ - "CPU": num_towers - }) - return minimize_loss_single_machine( - loss, accuracy, layer_collection, session_config=session_config) + allow_soft_placement=False, + device_count=device_count, + ) + + g_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices, + momentum=0.9) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + train_op = optimizer.minimize(loss, global_step=g_step) -def train_mnist_distributed(task_id, - num_worker_tasks, - num_ps_tasks, - master, - data_dir, - num_epochs, - use_fake_data=False): - """Train a ConvNet on MNIST. + def make_update_op(update_thunks): + update_op = [thunk() for thunk in update_thunks] + return tf.group(*update_op) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([train_op, cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _ = sess.run( + [g_step, loss, accuracy, inverse_op]) + + if (global_step_ + 1) % _INVERT_EVERY == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + +def train_mnist_distributed_sync_replicas(task_id, + is_chief, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + op_strategy, + use_fake_data=False): + """Train a ConvNet on MNIST using Sync replicas optimizer. Args: task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + is_chief: `boolean`, `True` if the worker is chief worker. num_worker_tasks: int. Number of workers in this distributed training setup. num_ps_tasks: int. Number of parameter servers holding variables. master: string. IP and port of TensorFlow runtime process. data_dir: string. Directory to read MNIST examples from. num_epochs: int. Number of passes to make over the training set. + op_strategy: `string`, Strategy to run the covariance and inverse + ops. If op_strategy == `chief_worker` then covaraiance and inverse + update ops are run on chief worker otherwise they are run on dedicated + workers. + use_fake_data: bool. If True, generate a synthetic dataset. Returns: accuracy of model on the final minibatch of training data. + + Raises: + ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. """ # Load a dataset. tf.logging.info("Loading MNIST into memory.") @@ -448,9 +649,17 @@ def train_mnist_distributed(task_id, # Fit model. checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") - return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, - master, checkpoint_dir, loss, accuracy, - layer_collection) + if op_strategy == "chief_worker": + return distributed_grads_only_and_ops_chief_worker( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + elif op_strategy == "dedicated_workers": + return distributed_grads_and_ops_dedicated_workers( + task_id, is_chief, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection) + else: + raise ValueError("Only supported op strategies are : {}, {}".format( + "chief_worker", "dedicated_workers")) if __name__ == "__main__": diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c2d4a9e9bfcc4bfb55a25d2f23e66afe5b1375 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +Distributed training with sync replicas optimizer. See +`convnet.train_mnist_distributed_sync_replicas` for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl import flags +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_integer("task", -1, "Task identifier") +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") +flags.DEFINE_string( + "cov_inv_op_strategy", "chief_worker", + "In dist training mode run the cov, inv ops on chief or dedicated workers." +) +flags.DEFINE_string("master", "local", "Session master.") +flags.DEFINE_integer("ps_tasks", 2, + "Number of tasks in the parameter server job.") +flags.DEFINE_integer("replicas_to_aggregate", 5, + "Number of replicas to aggregate.") +flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") +flags.DEFINE_integer("num_epochs", None, "Number of epochs.") + + +def _is_chief(): + """Determines whether a job is the chief worker.""" + if "chief_worker" in FLAGS.brain_jobs: + return FLAGS.brain_job_name == "chief_worker" + else: + return FLAGS.task == 0 + + +def main(unused_argv): + _ = unused_argv + convnet.train_mnist_distributed_sync_replicas( + FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, + FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py similarity index 57% rename from tensorflow/contrib/kfac/examples/convnet_mnist_main.py rename to tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py index b0c6fbde198850c76af0bc1600dc23e926227229..4249bf8a8d9d3a5beb87d4140a55b0ee6eadbc64 100644 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py @@ -14,44 +14,35 @@ # ============================================================================== r"""Train a ConvNet on MNIST using K-FAC. -See convnet.py for details. +Multi tower training mode. See `convnet.train_mnist_multitower` for details. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import argparse -import sys +from absl import flags import tensorflow as tf from tensorflow.contrib.kfac.examples import convnet -FLAGS = None +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") +flags.DEFINE_integer("num_towers", 2, + "Number of towers for multi tower training.") -def main(argv): - _ = argv - - if FLAGS.num_towers > 1: - convnet.train_mnist_multitower( - FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) - else: - convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) +def main(unused_argv): + _ = unused_argv + assert FLAGS.num_towers > 1 + devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] + convnet.train_mnist_multitower( + FLAGS.data_dir, + num_epochs=200, + num_towers=FLAGS.num_towers, + devices=devices) if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", - type=str, - default="/tmp/mnist", - help="Directory to store dataset in.") - parser.add_argument( - "--num_towers", - type=int, - default=1, - help="Number of CPUs to split minibatch across.") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) + tf.app.run(main=main) diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py similarity index 63% rename from tensorflow/contrib/bayesflow/python/ops/custom_grad.py rename to tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py index ca1ecb9c40204c3c723fa3423cfe148e823adc28..2c1f09936073a34816da61d771f59e848b8787af 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad.py +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py @@ -12,23 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for specifying custom gradients. +r"""Train a ConvNet on MNIST using K-FAC. -See ${python/contrib.bayesflow.custom_gradient}. +Train on single machine. See `convnet.train_mnist_single_machine` for details. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.custom_grad_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ - 'custom_gradient', -] +from absl import flags +import tensorflow as tf -remove_undocumented(__name__, _allowed_symbols) +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = flags.FLAGS +flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") + + +def main(unused_argv): + convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) + + +if __name__ == "__main__": + tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD index ce7da95c124beaed4773d68ce0d0c41f187f7c9d..ede7f183fe24f26bd86e232e831dea5f8ea1fdc4 100644 --- a/tensorflow/contrib/kfac/examples/tests/BUILD +++ b/tensorflow/contrib/kfac/examples/tests/BUILD @@ -50,15 +50,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 8d86c2bb5150cd4bc8a2b21ba050e904929e0fe9..6de775cc79953ba548c766e861d6d88e0455a508 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -112,15 +112,16 @@ class ConvNetTest(tf.test.TestCase): def testMinimizeLossSingleMachine(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy, - layer_collection) - self.assertLess(accuracy_, 1.0) + accuracy_ = convnet.minimize_loss_single_machine( + loss, accuracy, layer_collection, device="/cpu:0") + self.assertLess(accuracy_, 2.0) def testMinimizeLossDistributed(self): with tf.Graph().as_default(): loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_distributed( + accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", @@ -128,7 +129,7 @@ class ConvNetTest(tf.test.TestCase): loss=loss, accuracy=accuracy, layer_collection=layer_collection) - self.assertLess(accuracy_, 1.0) + self.assertLess(accuracy_, 2.0) def testTrainMnistSingleMachine(self): with tf.Graph().as_default(): @@ -138,7 +139,7 @@ class ConvNetTest(tf.test.TestCase): # but there are too few parameters for the model to effectively memorize # the training set the way an MLP can. convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True) + data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") def testTrainMnistMultitower(self): with tf.Graph().as_default(): @@ -149,13 +150,15 @@ class ConvNetTest(tf.test.TestCase): def testTrainMnistDistributed(self): with tf.Graph().as_default(): # Ensure model training doesn't crash. - convnet.train_mnist_distributed( + convnet.train_mnist_distributed_sync_replicas( task_id=0, + is_chief=True, num_worker_tasks=1, num_ps_tasks=0, master="", data_dir=None, num_epochs=1, + op_strategy="chief_worker", use_fake_data=True) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index f4ed978174a9ddd8b54a88e60bfb48a67a2e76d2..2477d2bfc12c2df64a672fd457e9634009ccd129 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -36,6 +36,7 @@ py_test( srcs = ["fisher_factors_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -113,6 +114,7 @@ py_test( name = "utils_test", srcs = ["utils_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/contrib/tpu", @@ -154,15 +156,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index bfdb69ad02caaa57827e0ae6b3c9fc0d0ed03754..f22dbcf21566297340f3b4158a810f6d03af12f5 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -40,30 +39,6 @@ from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] -class DeviceContextGeneratorTest(test.TestCase): - - def testNoDevice(self): - device_context_generator = estimator._DeviceContextGenerator(None) - with ops.device("/device:CPU:0"): # This is what will be used - with device_context_generator(): # Does nothing - a = constant_op.constant([2.0], name="a") - self.assertEqual("/device:CPU:0", a.op.device) - - def testTwoDevices(self): - device_context_generator = estimator._DeviceContextGenerator( - ["/device:GPU:0", "/device:GPU:1"]) - with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes - with device_context_generator(): - a = constant_op.constant([2.0], name="a") - with device_context_generator(): - b = constant_op.constant([2.0], name="b") - with device_context_generator(): - c = constant_op.constant([2.0], name="c") - self.assertEqual("/device:GPU:0", a.op.device) - self.assertEqual("/device:GPU:1", b.op.device) - self.assertEqual("/device:GPU:0", c.op.device) - - class EstimatorTest(test.TestCase): def setUp(self): @@ -90,57 +65,113 @@ class EstimatorTest(test.TestCase): def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, - self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights, self.bias], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection + ) + est.make_ops_and_vars() # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_ops_and_vars() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection) + est.make_ops_and_vars() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, - "not_a_real_mode") + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="not_a_real_mode") + est.make_ops_and_vars() - def testModeListCorrect(self): + def testGradientsModeBuild(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection) - self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys()) + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="gradients") + est.make_ops_and_vars() - def testAllModesBuild(self): - for mode in _ALL_ESTIMATION_MODES: - with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, - self.layer_collection, mode) + def testEmpiricalModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="empirical") + est.make_ops_and_vars() + + def testCurvaturePropModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="curvature_prop") + est.make_ops_and_vars() + + def testExactModeBuild(self): + with self._graph.as_default(): + est = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + cov_ema_decay=0.1, + damping=0.2, + layer_collection=self.layer_collection, + estimation_mode="exact") + est.make_ops_and_vars() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, - cov_ema_decay=0.0, - damping=0.0) + damping=0.2, + cov_ema_decay=0.0) # Construct an op that executes one covariance update per step. global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, cov_update_op_thunks, _, + _) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() cov_matrices = [ fisher_factor.get_cov() for fisher_factor in self.layer_collection.get_factors() ] - cov_update_op_thunks = fisher_estimator.cov_update_thunks cov_update_op = control_flow_ops.case( [(math_ops.equal(global_step, i), thunk) for i, thunk in enumerate(cov_update_op_thunks)]) @@ -172,23 +203,61 @@ class EstimatorTest(test.TestCase): sess.run(cov_update_op) sess.run(increment_global_step) + def test_round_robin_placement(self): + """Check if the ops and variables are placed on devices correctly.""" + with self._graph.as_default(): + fisher_estimator = estimator.FisherEstimatorRoundRobin( + variables=[self.weights], + layer_collection=self.layer_collection, + damping=0.2, + cov_ema_decay=0.0, + cov_devices=["/cpu:{}".format(i) for i in range(2)], + inv_devices=["/cpu:{}".format(i) for i in range(2)]) + + # Construct an op that executes one covariance update per step. + (cov_update_ops, _, inv_update_ops, _, _, + _) = fisher_estimator.make_ops_and_vars(scope="test") + self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") + self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") + self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") + self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() + ] + self.assertEqual(cov_matrices[0].device, "/device:CPU:0") + self.assertEqual(cov_matrices[1].device, "/device:CPU:1") + # Inverse matrices need to be explicitly placed. + self.assertEqual(inv_matrices[0].device, "") + self.assertEqual(inv_matrices[1].device, "") + def test_inv_update_thunks(self): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: - fisher_estimator = estimator.FisherEstimator( + fisher_estimator = estimator.FisherEstimatorRoundRobin( variables=[self.weights], layer_collection=self.layer_collection, - cov_ema_decay=0.0, - damping=0.0) + damping=0.2, + cov_ema_decay=0.0) # Construct op that updates one inverse per global step. global_step = training_util.get_or_create_global_step() + (cov_variable_thunks, _, inv_variable_thunks, + inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks() + for thunk in cov_variable_thunks: + thunk() + for thunk in inv_variable_thunks: + thunk() inv_matrices = [ matrix for fisher_factor in self.layer_collection.get_factors() - for matrix in fisher_factor._inverses_by_damping.values() + for matrix in fisher_factor._matpower_by_exp_and_damping.values() ] - inv_update_op_thunks = fisher_estimator.inv_update_thunks inv_update_op = control_flow_ops.case( [(math_ops.equal(global_step, i), thunk) for i, thunk in enumerate(inv_update_op_thunks)]) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 82accd57f0c37d140238f1884fce956654d14227..6eda6c31e34370fd2bea1192ebf777924824c8e3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops @@ -62,7 +63,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -71,7 +72,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -80,7 +81,7 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -90,9 +91,12 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -108,9 +112,12 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -126,10 +133,13 @@ class FullFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) damping = 0.5 block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() + block.register_inverse() + block._factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) @@ -153,7 +163,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -162,7 +172,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -171,7 +181,7 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -181,9 +191,10 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -199,9 +210,10 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = params**2 block.instantiate_factors((grads,), 0.5) + block._factor.instantiate_cov_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -216,10 +228,11 @@ class NaiveDiagonalFBTest(test.TestCase): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_minibatch(32) + block.register_additional_tower(32) grads = (params[0]**2, math_ops.sqrt(params[1])) damping = 0.5 block.instantiate_factors((grads,), damping) + block._factor.instantiate_cov_variables() cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) sess.run(state_ops.assign(block._factor._cov, cov)) @@ -236,10 +249,10 @@ class NaiveDiagonalFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) -class FullyConnectedDiagonalFB(test.TestCase): +class FullyConnectedDiagonalFBTest(test.TestCase): def setUp(self): - super(FullyConnectedDiagonalFB, self).setUp() + super(FullyConnectedDiagonalFBTest, self).setUp() self.batch_size = 4 self.input_size = 6 @@ -311,8 +324,8 @@ class FullyConnectedDiagonalFB(test.TestCase): self.assertAllClose(expected_result, result) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -363,9 +376,10 @@ class FullyConnectedDiagonalFB(test.TestCase): block = fb.FullyConnectedDiagonalFB( lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) sess.run(block._factor.make_covariance_update_op(0.0)) @@ -375,6 +389,70 @@ class FullyConnectedDiagonalFB(test.TestCase): return multiply_result, multiply_inverse_result +class EmbeddingKFACFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_tower(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(((grads,),), damping) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_tower(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Create a sparse update. + indices = array_ops.constant([1, 3, 4]) + values = array_ops.constant([[1.], [1.], [1.]]) + sparse_vector = ops.IndexedSlices( + values, indices, dense_shape=[vocab_size, 1]) + dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1]) + + # Compare Fisher-vector product against explicit result. + result = block.multiply_inverse(sparse_vector) + expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), + dense_vector) + + sess.run(tf_variables.global_variables_initializer()) + self.assertAlmostEqual( + sess.run(expected_result[1]), sess.run(result.values[0])) + self.assertAlmostEqual( + sess.run(expected_result[3]), sess.run(result.values[1])) + self.assertAlmostEqual( + sess.run(expected_result[4]), sess.run(result.values[2])) + + class FullyConnectedKFACBasicFBTest(test.TestCase): def testFullyConnectedKFACBasicFBInit(self): @@ -383,7 +461,7 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) @@ -393,10 +471,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): @@ -404,10 +482,10 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -415,9 +493,15 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -441,9 +525,14 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): inputs = array_ops.constant([[1., 2.], [3., 4.]]) outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -464,13 +553,20 @@ class FullyConnectedKFACBasicFBTest(test.TestCase): outputs = array_ops.zeros([32, output_dim]) params = array_ops.zeros([input_dim, output_dim]) block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_minibatch(inputs, outputs) + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) @@ -593,8 +689,8 @@ class ConvDiagonalFBTest(test.TestCase): self.assertAllClose(expected_result, result, atol=1e-3) - def testRegisterAdditionalMinibatch(self): - """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + def testRegisterAdditionalTower(self): + """Ensure 1 big tower and 2 small towers are equivalent.""" multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( @@ -655,9 +751,10 @@ class ConvDiagonalFBTest(test.TestCase): block = fb.ConvDiagonalFB( lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') for (i, o) in zip(inputs, outputs): - block.register_additional_minibatch(i, o) + block.register_additional_tower(i, o) block.instantiate_factors((output_grads,), damping=0.0) + block._factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) sess.run(block._factor.make_covariance_update_op(0.0)) @@ -667,6 +764,54 @@ class ConvDiagonalFBTest(test.TestCase): return multiply_result, multiply_inverse_result +class DepthwiseConvKFCBasicFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((3, 3, 8, 2)) + inputs = random_ops.random_normal((32, 5, 5, 8)) + outputs = random_ops.random_normal((32, 5, 5, 16)) + layer_collection = lc.LayerCollection() + block = fb.DepthwiseConvKFCBasicFB( + layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') + block.register_additional_tower(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() + + # Ensure inverse update op doesn't crash. + sess.run(tf_variables.global_variables_initializer()) + sess.run([ + factor.make_inverse_update_ops() + for factor in layer_collection.get_factors() + ]) + + # Ensure inverse-vector multiply doesn't crash. + output = block.multiply_inverse(params) + sess.run(output) + + # Ensure same shape. + self.assertAllEqual(output.shape, params.shape) + + class ConvKFCBasicFBTest(test.TestCase): def _testConvKFCBasicFBInitParams(self, params): @@ -678,16 +823,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.constant(params) inputs = random_ops.random_normal((2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertAllEqual([outputs], block.tensors_to_compute_grads()) def testConvKFCBasicFBInitParamsParamsTuple(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) def testConvKFCBasicFBInitParamsParamsSingle(self): - self._testConvKFCBasicFBInitParams([np.array([1., 2.])]) + self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) def testMultiplyInverseTuple(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -695,11 +841,16 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -721,12 +872,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = random_ops.random_normal((2, 2, 2, 2)) inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertFalse(block._has_bias) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -744,12 +900,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = [random_ops.random_normal((2, 2, 2, 2))] inputs = random_ops.random_normal((2, 2, 2, 2)) outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) self.assertTrue(block._has_bias) grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) + block.instantiate_factors(((grads,),), 0.5) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() # Make sure our inverse is something other than the identity. sess.run(tf_variables.global_variables_initializer()) @@ -767,12 +928,17 @@ class ConvKFCBasicFBTest(test.TestCase): params = array_ops.zeros((2, 2, 2, 2)) inputs = array_ops.zeros((2, 2, 2, 2)) outputs = array_ops.zeros((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), - 'SAME') - block.register_additional_minibatch(inputs, outputs) + block = fb.ConvKFCBasicFB( + lc.LayerCollection(), params=params, padding='SAME') + block.register_additional_tower(inputs, outputs) grads = outputs**2 damping = 0. # This test is only valid without damping. - block.instantiate_factors(([grads],), damping) + block.instantiate_factors(((grads,),), damping) + block._input_factor.instantiate_cov_variables() + block._output_factor.instantiate_cov_variables() + block.register_inverse() + block._input_factor.instantiate_inv_variables() + block._output_factor.instantiate_inv_variables() sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) @@ -797,9 +963,9 @@ class FullyConnectedSeriesFBTest(test.TestCase): random_seed.set_random_seed(200) inputs = array_ops.constant([1., 2.]) outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) + block.register_additional_tower([inputs], [outputs]) + self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) def testInstantiateFactorsHasBias(self): with ops.Graph().as_default(): @@ -808,11 +974,10 @@ class FullyConnectedSeriesFBTest(test.TestCase): outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), - inputs=[inputs], - outputs=[outputs], has_bias=True) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) + block.instantiate_factors((((grads,),),), 0.5) def testInstantiateFactorsNoBias(self): with ops.Graph().as_default(): @@ -821,11 +986,10 @@ class FullyConnectedSeriesFBTest(test.TestCase): outputs = array_ops.constant([[3., 4.], [5., 6.]]) block = fb.FullyConnectedSeriesFB( lc.LayerCollection(), - inputs=[inputs], - outputs=[outputs], has_bias=False) + block.register_additional_tower([inputs], [outputs]) grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) + block.instantiate_factors((((grads,),),), 0.5) def as_tensors(tensor_or_tuple): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 753378d9f4a0d8762bafbee2ec27d6c71783dda1..2a3592c53fdda488561e504ba2712aadc3214cc4 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np import numpy.random as npr +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,36 +30,13 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test -class MaybeColocateTest(test.TestCase): - - def setUp(self): - self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS - - def tearDown(self): - ff.set_global_constants( - colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs) - - def testFalse(self): - ff.set_global_constants(colocate_cov_ops_with_inputs=False) - with tf_ops.Graph().as_default(): - a = constant_op.constant([2.0], name='a') - with ff.maybe_colocate_with(a): - b = constant_op.constant(3.0, name='b') - self.assertEqual([b'loc:@a'], a.op.colocation_groups()) - self.assertEqual([b'loc:@b'], b.op.colocation_groups()) - - def testTrue(self): - ff.set_global_constants(colocate_cov_ops_with_inputs=True) - with tf_ops.Graph().as_default(): - a = constant_op.constant([2.0], name='a') - with ff.maybe_colocate_with(a): - b = constant_op.constant(3.0, name='b') - self.assertEqual([b'loc:@a'], a.op.colocation_groups()) - self.assertEqual([b'loc:@a'], b.op.colocation_groups()) +def make_damping_func(damping): + return fb._package_func(lambda: damping, damping) class FisherFactorTestingDummy(ff.FisherFactor): @@ -89,6 +67,30 @@ class FisherFactorTestingDummy(ff.FisherFactor): def make_inverse_update_ops(self): return [] + def get_cov(self): + return NotImplementedError + + def left_multiply(self, x, damping): + return NotImplementedError + + def right_multiply(self, x, damping): + return NotImplementedError + + def left_multiply_matpower(self, x, exp, damping): + return NotImplementedError + + def right_multiply_matpower(self, x, exp, damping): + return NotImplementedError + + def instantiate_inv_variables(self): + return NotImplementedError + + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -120,6 +122,12 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def instantiate_covariance(self): pass + def _num_towers(self): + raise NotImplementedError + + def _get_data_device(self): + raise NotImplementedError + class NumericalUtilsTest(test.TestCase): @@ -231,21 +239,24 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' - dampings = 0.1, 1e-1, 0.00001, 1e-5 + damping_funcs = [make_damping_func(0.1), + make_damping_func(0.1), + make_damping_func(1e-5), + make_damping_func(1e-5)] + for damping_func in damping_funcs: + factor.register_inverse(damping_func) - for damping in dampings: - factor.register_damped_inverse(damping) + factor.instantiate_inv_variables() - self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys())) - inv = factor._inverses_by_damping[dampings[0]] - self.assertEqual(inv, factor._inverses_by_damping[dampings[1]]) - self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]]) - self.assertEqual(factor._inverses_by_damping[dampings[2]], - factor._inverses_by_damping[dampings[3]]) + inv = factor.get_inverse(damping_funcs[0]) + self.assertEqual(inv, factor.get_inverse(damping_funcs[1])) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2])) + self.assertEqual(factor.get_inverse(damping_funcs[2]), + factor.get_inverse(damping_funcs[3])) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]], - factor_vars) + self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]), + set(factor_vars)) self.assertEqual(shape, inv.get_shape()) def testRegisterMatpower(self): @@ -255,17 +266,22 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' - factor.register_matpower(1, 0.5) - factor.register_matpower(2, 0.5) + # TODO(b/74201126): Change to using the same func for both once + # Topohash is in place. + damping_func_1 = make_damping_func(0.5) + damping_func_2 = make_damping_func(0.5) + + factor.register_matpower(-0.5, damping_func_1) + factor.register_matpower(2, damping_func_2) + + factor.instantiate_inv_variables() - self.assertEqual( - set([(1, 0.5), (2, 0.5)]), - set(factor._matpower_by_exp_and_damping.keys())) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - matpower1 = factor.get_matpower(1, 0.5) - matpower2 = factor.get_matpower(2, 0.5) - self.assertListEqual([matpower1, matpower2], factor_vars) + matpower1 = factor.get_matpower(-0.5, damping_func_1) + matpower2 = factor.get_matpower(2, damping_func_2) + + self.assertEqual(set([matpower1, matpower2]), set(factor_vars)) self.assertEqual(shape, matpower1.get_shape()) self.assertEqual(shape, matpower2.get_shape()) @@ -284,17 +300,24 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + damping_funcs = [] for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): - factor.register_damped_inverse(1. / i) + damping_funcs.append(make_damping_func(1./i)) + + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): + factor.register_inverse(damping_funcs[i]) + + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] sess.run(ops) - for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) + new_invs.append(sess.run(factor.get_inverse(damping_funcs[i]))) + # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): for j in range(i + 1, len(new_invs)): @@ -309,14 +332,16 @@ class InverseProvidingFactorTest(test.TestCase): factor._cov = array_ops.constant(cov, dtype=dtypes.float32) exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power damping = 0.5 + damping_func = make_damping_func(damping) - factor.register_matpower(exp, damping) + factor.register_matpower(exp, damping_func) + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) sess.run(ops[0]) - matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)]) + matpower = sess.run(factor.get_matpower(exp, damping_func)) matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) self.assertAllClose(matpower, matpower_np) @@ -327,18 +352,21 @@ class InverseProvidingFactorTest(test.TestCase): factor = InverseProvidingFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - factor.register_damped_inverse(0) + damping_func = make_damping_func(0) + + factor.register_inverse(damping_func) + factor.instantiate_inv_variables() ops = factor.make_inverse_update_ops() self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) # The inverse op will assign the damped inverse of cov to the inv var. - old_inv = sess.run(factor._inverses_by_damping[0]) + old_inv = sess.run(factor.get_inverse(damping_func)) self.assertAllClose( sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) sess.run(ops) - new_inv = sess.run(factor._inverses_by_damping[0]) + new_inv = sess.run(factor.get_inverse(damping_func)) self.assertAllClose(new_inv, np.linalg.inv(cov)) @@ -349,6 +377,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) def testFullFactorInitFloat64(self): @@ -357,6 +386,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullFactor((tensor,), 32) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 6], cov.get_shape().as_list()) @@ -366,6 +396,7 @@ class FullFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.constant([1., 2.], name='a/b/c') factor = ff.FullFactor((tensor,), 2) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -379,7 +410,8 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + factor.instantiate_cov_variables() + self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -387,7 +419,8 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - cov = factor.get_cov() + factor.instantiate_cov_variables() + cov = factor.get_cov_var() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -396,12 +429,150 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.constant([1., 2.], name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 2) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) self.assertAllClose([[0.75], [1.5]], new_cov) +class EmbeddingInputKroneckerFactorTest(test.TestCase): + + def testInitialization(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() + cov = factor.get_cov_var() + self.assertEqual(cov.shape.as_list(), [vocab_size]) + + def testCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + factor.instantiate_cov_variables() + cov_update_op = factor.make_covariance_update_op(0.0) + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(cov_update_op) + self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) + + +class ConvDiagonalFactorTest(test.TestCase): + + def setUp(self): + self.batch_size = 10 + self.height = self.width = 32 + self.in_channels = 3 + self.out_channels = 1 + self.kernel_height = self.kernel_width = 3 + self.strides = [1, 2, 2, 1] + self.data_format = 'NHWC' + self.padding = 'SAME' + self.kernel_shape = [ + self.kernel_height, self.kernel_width, self.in_channels, + self.out_channels + ] + + def testInit(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format) + factor.instantiate_cov_variables() + + # Ensure covariance matrix's shape makes sense. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels, + self.out_channels + ], + factor.get_cov_var().shape.as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + # Construct all arguments such that convolution kernel is applied in + # exactly one spatial location. + inputs = np.random.randn( + 1, # batch_size + self.kernel_height, + self.kernel_width, + self.in_channels) # in_channels + outputs_grad = np.random.randn( + 1, # batch_size + 1, # output_height + 1, # output_width + self.out_channels) + + factor = ff.ConvDiagonalFactor( + (constant_op.constant(inputs),), + ((constant_op.constant(outputs_grad),),), + self.kernel_shape, + strides=[1, 1, 1, 1], + padding='VALID') + factor.instantiate_cov_variables() + + # Completely forget initial value on first update. + cov_update_op = factor.make_covariance_update_op(0.0) + + # Ensure new covariance value is same as outer-product of inputs/outputs + # vectorized, squared. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + cov = sess.run(cov_update_op) + expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 + self.assertAllClose(expected_cov, cov) + + def testHasBias(self): + with tf_ops.Graph().as_default(): + inputs = random_ops.random_uniform( + [self.batch_size, self.height, self.width, self.in_channels]) + outputs_grads = [ + random_ops.random_uniform([ + self.batch_size, self.height // self.strides[1], + self.width // self.strides[2], self.out_channels + ]) for _ in range(3) + ] + + factor = ff.ConvDiagonalFactor( + (inputs,), + (outputs_grads,), + self.kernel_shape, + self.strides, + self.padding, + data_format=self.data_format, + has_bias=True) + factor.instantiate_cov_variables() + + # Ensure shape accounts for bias. + self.assertEqual([ + self.kernel_height * self.kernel_width * self.in_channels + 1, + self.out_channels + ], + factor.get_cov_var().shape.as_list()) + + # Ensure update op doesn't crash. + cov_update_op = factor.make_covariance_update_op(0.0) + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(cov_update_op) + + class FullyConnectedKroneckerFactorTest(test.TestCase): def _testFullyConnectedKroneckerFactorInit(self, @@ -411,7 +582,8 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual(final_shape, cov.get_shape().as_list()) @@ -428,7 +600,8 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -438,40 +611,171 @@ class FullyConnectedKroneckerFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor((tensor,)) + factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) -class ConvInputKroneckerFactorTest(test.TestCase): +class ConvFactorTestCase(test.TestCase): + + def assertMatrixRank(self, rank, matrix, atol=1e-5): + assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' + eigvals = np.linalg.eigvals(matrix) + nnz_eigvals = np.sum(eigvals > atol) + self.assertEqual( + rank, + nnz_eigvals, + msg=('Found %d of %d expected non-zero eigenvalues: %s.' % + (nnz_eigvals, rank, eigvals))) + + +class ConvInputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**3 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, width, in_channels), seed=0),), + filter_shape=(width, width, width, in_channels, out_channels), + padding='SAME', + strides=(2, 2, 2), + extract_patches_fn='extract_convolution_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + input_size = in_channels * (width**3) + self.assertEqual([input_size, input_size], + factor.get_cov_var().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank-8, as the filter will be applied at each corner of + # the 4-D cube. + self.assertMatrixRank(8, cov) + + def testPointwiseConv2d(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 1, 1, 1), + extract_patches_fn='extract_pointwise_conv2d_patches', + has_bias=False) + factor.instantiate_cov_variables() + + # Ensure shape of covariance matches input size of filter. + self.assertEqual([in_channels, in_channels], + factor.get_cov_var().shape.as_list()) + + # Ensure cov_update_op doesn't crash. + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank-9, as the filter will be applied at each location. + self.assertMatrixRank(9, cov) + + def testStrides(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 3**2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(1, 1, in_channels, out_channels), + padding='SAME', + strides=(1, 2, 1, 1), + extract_patches_fn='extract_image_patches', + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be the sum of 3 * 2 = 6 outer products. + self.assertMatrixRank(6, cov) + + def testDilationRate(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + in_channels = 2 + out_channels = 4 + + factor = ff.ConvInputKroneckerFactor( + inputs=(random_ops.random_uniform( + (batch_size, width, width, in_channels), seed=0),), + filter_shape=(3, 3, in_channels, out_channels), + padding='SAME', + extract_patches_fn='extract_image_patches', + strides=(1, 1, 1, 1), + dilation_rate=(1, width, width, 1), + has_bias=False) + factor.instantiate_cov_variables() + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov_var()) + + # Cov should be rank = in_channels, as only the center of the filter + # receives non-zero input for each input channel. + self.assertMatrixRank(in_channels, cov) def testConvInputKroneckerFactorInitNoBias(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=False) + inputs=(tensor,), + filter_shape=(1, 2, 3, 4), + padding='SAME', + has_bias=False) + factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3, 1 * 2 * 3], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) def testConvInputKroneckerFactorInitFloat64(self): with tf_ops.Graph().as_default(): dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], @@ -479,37 +783,67 @@ class ConvInputKroneckerFactorTest(test.TestCase): def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) factor = ff.ConvInputKroneckerFactor( - tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True) + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]], - new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose( + [ + [(1. + 4.) / 2., (1. + 2.) / 2.], # + [(1. + 2.) / 2., (1. + 1.) / 2.] + ], # + new_cov) def testMakeCovarianceUpdateOpNoBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) + input_shape = (2, 1, 1, 1) tensor = array_ops.constant( - np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) - factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1], - 'SAME') + np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( + np.float32)) + factor = ff.ConvInputKroneckerFactor( + (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[34.375, 37], [37, 41]], new_cov) + new_cov = sess.run(factor.make_covariance_update_op(0.)) + self.assertAllClose([[(1. + 4.) / 2.]], new_cov) + + +class ConvOutputKroneckerFactorTest(ConvFactorTestCase): + + def test3DConvolution(self): + with tf_ops.Graph().as_default(): + batch_size = 1 + width = 3 + out_channels = width**3 + factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ + random_ops.random_uniform( + (batch_size, width, width, width, out_channels), seed=0) + ],)) + factor.instantiate_cov_variables() -class ConvOutputKroneckerFactorTest(test.TestCase): + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + sess.run(factor.make_covariance_update_op(0.0)) + cov = sess.run(factor.get_cov()) + + # Cov should be rank 3^3, as each spatial position donates a rank-1 + # update. + self.assertMatrixRank(width**3, cov) def testConvOutputKroneckerFactorInit(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) def testConvOutputKroneckerFactorInitFloat64(self): @@ -517,23 +851,18 @@ class ConvOutputKroneckerFactorTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') - factor = ff.ConvOutputKroneckerFactor((tensor,)) + factor = ff.ConvOutputKroneckerFactor(((tensor,),)) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([5, 5], cov.get_shape().as_list()) - def testConvOutputKroneckerFactorInitNotEnoughDims(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - with self.assertRaises(IndexError): - ff.ConvOutputKroneckerFactor(tensor) - def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) - factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),)) + factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -546,8 +875,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) def testFullyConnectedMultiKFInitFloat64(self): @@ -555,8 +884,8 @@ class FullyConnectedMultiKFTest(test.TestCase): dtype = dtypes.float64_ref random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) + factor.instantiate_cov_variables() cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([3, 3], cov.get_shape().as_list()) @@ -565,8 +894,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) @@ -576,8 +905,8 @@ class FullyConnectedMultiKFTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - tensor_list = [tensor] - factor = ff.FullyConnectedMultiKF((tensor_list,)) + factor = ff.FullyConnectedMultiKF(((tensor,),)) + factor.instantiate_cov_variables() sess.run(tf_variables.global_variables_initializer()) new_cov = sess.run(factor.make_covariance_update_op(.5)) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index b8ccbeadd0a9d69edb41fef50e3edb090457adf2..cb80fca3705308f92e308e2a840336fb72d0fa62 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -35,7 +35,7 @@ from tensorflow.python.platform import test class MockFisherBlock(object): """A fake FisherBlock.""" - num_registered_minibatches = 2 + num_registered_towers = 2 def __init__(self, name='MockFisherBlock'): self.name = name @@ -104,22 +104,53 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5))) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], - 'SAME', - array_ops.ones((1, 1, 1, 1)), - array_ops.constant(3), + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=array_ops.ones((1, 2, 3, 4)), + outputs=array_ops.ones((1, 1, 1, 5)), approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_separable_conv2d( + depthwise_params=array_ops.ones((3, 3, 1, 2)), + pointwise_params=array_ops.ones((1, 1, 2, 4)), + inputs=array_ops.ones((32, 5, 5, 1)), + depthwise_outputs=array_ops.ones((32, 5, 5, 2)), + pointwise_outputs=array_ops.ones((32, 5, 5, 4)), + strides=[1, 1, 1, 1], + padding='SAME') + lc.register_convolution( + params=array_ops.ones((3, 3, 1, 8)), + inputs=array_ops.ones((32, 5, 5, 1)), + outputs=array_ops.ones((32, 5, 5, 8)), + padding='SAME') lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) lc.register_generic( array_ops.constant(6), 16, approx=layer_collection.APPROX_DIAGONAL_NAME) - - self.assertEqual(6, len(lc.get_blocks())) + lc.register_fully_connected_multi( + array_ops.constant(1), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + lc.register_conv2d_multi( + params=array_ops.ones((2, 3, 4, 5)), + strides=[1, 1, 1, 1], + padding='SAME', + inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), + outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) + lc.register_embedding_multi( + array_ops.constant((1,)), + (array_ops.constant(2), array_ops.constant(3)), + (array_ops.constant(4), array_ops.constant(5))) + + self.assertEqual(12, len(lc.get_blocks())) def testRegisterBlocksMultipleRegistrations(self): with ops.Graph().as_default(): @@ -237,16 +268,16 @@ class LayerCollectionTest(test.TestCase): # Create a new loss function by name. lc.register_categorical_predictive_distribution(logits, name='loss1') - self.assertEqual(1, len(lc.losses)) + self.assertEqual(1, len(lc.towers_by_loss)) # Add logits to same loss function. lc.register_categorical_predictive_distribution( logits, name='loss1', reuse=True) - self.assertEqual(1, len(lc.losses)) + self.assertEqual(1, len(lc.towers_by_loss)) # Add another new loss function. lc.register_categorical_predictive_distribution(logits, name='loss2') - self.assertEqual(2, len(lc.losses)) + self.assertEqual(2, len(lc.towers_by_loss)) def testLossFunctionWithoutName(self): """Ensure loss functions get unique names if 'name' not specified.""" @@ -298,13 +329,9 @@ class LayerCollectionTest(test.TestCase): name='loss1', reuse=layer_collection.VARIABLE_SCOPE) - self.assertEqual(len(lc.losses), 1) - loss = lc.losses[0] - + self.assertEqual(len(lc.towers_by_loss), 1) # Three successful registrations. - self.assertEqual(loss.params.shape.as_list(), - [3 * batch_size, output_size]) - self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size]) + self.assertEqual(len(lc.towers_by_loss[0]), 3) def testRegisterCategoricalPredictiveDistributionBatchSize1(self): with ops.Graph().as_default(): @@ -441,13 +468,13 @@ class LayerCollectionTest(test.TestCase): b = variable_scope.get_variable('b', [3]) lc = layer_collection.LayerCollection() lc.register_fully_connected(w, inputs, outputs) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) with self.assertRaises(KeyError): lc.register_fully_connected((w, b), inputs, outputs, reuse=True) self.assertNotIn((w, b), lc.fisher_blocks) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) lc.register_fully_connected(w, inputs, outputs, reuse=True) - self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) def testMakeOrGetFactor(self): with ops.Graph().as_default(): @@ -479,17 +506,6 @@ class LayerCollectionTest(test.TestCase): variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) self.assertTrue(all([var.name.startswith(scope) for var in variables])) - def testGetUseCountMap(self): - """Ensure get_use_count_map() sums 'num_registered_minibatches'.""" - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - 'a': MockFisherBlock(), - ('a', 'c'): MockFisherBlock(), - ('b', 'c'): MockFisherBlock() - } - use_count_map = lc.get_use_count_map() - self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) - def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', shape=()) y = variable_scope.get_variable('y', shape=()) @@ -550,6 +566,32 @@ class LayerCollectionTest(test.TestCase): self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + def testDefaultLayerCollection(self): + with ops.Graph().as_default(): + # Can't get default if there isn't one set. + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # Can't set default twice. + lc = layer_collection.LayerCollection() + layer_collection.set_default_layer_collection(lc) + with self.assertRaises(ValueError): + layer_collection.set_default_layer_collection(lc) + + # Same as one set. + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + + # Can set to None. + layer_collection.set_default_layer_collection(None) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + + # as_default() is the same as setting/clearing. + with lc.as_default(): + self.assertTrue(lc is layer_collection.get_default_layer_collection()) + with self.assertRaises(ValueError): + layer_collection.get_default_layer_collection() + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index ae787b6f1ac90218f2ac73d37fb270df0b822de2..c00af5593f085e3b1f3e030a24f4b821115cc869 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -24,7 +24,6 @@ from tensorflow.contrib.kfac.python.ops import loss_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -97,22 +96,6 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) - def testMultiMinibatchRegistration(self): - """Ensure this loss function supports registering multiple minibatches.""" - with ops.Graph().as_default(): - tower_logits = [] - loss = None - num_towers = 5 - for _ in range(num_towers): - logits = random_ops.random_uniform(shape=[2, 3]) - tower_logits.append(logits) - if loss is None: - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - else: - loss.register_additional_minibatch(logits) - self.assertListEqual(loss.input_minibatches, tower_logits) - self.assertEqual(loss.num_registered_minibatches, num_towers) - def testMultiplyFisherSingleVector(self): with ops.Graph().as_default(), self.test_session() as sess: logits = np.array([1., 2., 3.]) @@ -203,23 +186,5 @@ class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) - def testMultiMinibatchRegistration(self): - """Ensure this loss function supports registering multiple minibatches.""" - with ops.Graph().as_default(): - tower_logits = [] - loss = None - num_towers = 5 - for _ in range(num_towers): - logits = random_ops.random_uniform(shape=[2, 3]) - tower_logits.append(logits) - if loss is None: - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - logits) - else: - loss.register_additional_minibatch(logits) - self.assertListEqual(loss.input_minibatches, tower_logits) - self.assertEqual(loss.num_registered_minibatches, num_towers) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 97a97adbf5577cd2694d3055acaa59258ad27964..2cee01212a11595669e9df0fc95a5657926c1038 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -325,6 +327,84 @@ class UtilsTest(test.TestCase): ], values) + def testExtractConvolutionPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_spatial_shape = [9, 10, 11] + in_channels = out_channels = 32 + kernel_spatial_shape = [5, 3, 3] + spatial_strides = [1, 2, 1] + spatial_dilation = [1, 1, 1] + padding = 'SAME' + + images = random_ops.random_uniform( + [batch_size] + image_spatial_shape + [in_channels], seed=0) + kernel_shape = kernel_spatial_shape + [in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_convolution_patches( + images, + kernel_shape, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + result_spatial_shape = ( + patches.shape.as_list()[1:1 + len(image_spatial_shape)]) + self.assertEqual(patches.shape.as_list(), + [batch_size] + result_spatial_shape + + kernel_spatial_shape + [in_channels]) + + # Ensure extract...patches() + matmul() and convolution() implementation + # give the same answer. + outputs = nn_ops.convolution( + images, + kernel, + padding, + strides=spatial_strides, + dilation_rate=spatial_dilation) + + patches_flat = array_ops.reshape( + patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + + def testExtractPointwiseConv2dPatches(self): + with ops.Graph().as_default(), self.test_session() as sess: + batch_size = 10 + image_height = image_width = 8 + in_channels = out_channels = 3 + kernel_height = kernel_width = 1 + strides = [1, 1, 1, 1] + padding = 'VALID' + + images = random_ops.random_uniform( + [batch_size, image_height, image_width, in_channels], seed=0) + kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] + kernel = random_ops.random_uniform(kernel_shape, seed=1) + + # Ensure shape matches expectation. + patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) + self.assertEqual(patches.shape.as_list(), [ + batch_size, image_height, image_width, kernel_height, kernel_width, + in_channels + ]) + + # Ensure extract...patches() + matmul() and conv2d() implementation + # give the same answer. + outputs = nn_ops.conv2d(images, kernel, strides, padding) + + patches_flat = array_ops.reshape( + patches, [-1, kernel_height * kernel_width * in_channels]) + kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) + outputs_flat = math_ops.matmul(patches_flat, kernel_flat) + + outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) + self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index ee6549b109399766579b6ea18a987ae2c8275983..b897fd68a080e819042cd36f2a1acfcf175e656b 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -144,10 +144,13 @@ py_library( ":fisher_estimator", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], ) @@ -168,6 +171,7 @@ py_library( name = "fisher_estimator", srcs = [ "estimator.py", + "placement.py", ], srcs_version = "PY2AND3", deps = [ @@ -177,6 +181,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", + "@six_archive//:six", ], ) @@ -239,15 +244,3 @@ py_library( "//tensorflow/python:util", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index a7b1f9d35c931fc44408be804479e758f28f7110..d11c9c828810742cd176e4c5b7b77cf9a5cf87d9 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,68 +18,59 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import contextlib -import itertools - +import abc import numpy as np +import six +from tensorflow.contrib.kfac.python.ops import placement from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest -class _DeviceContextGenerator(object): - """Class for generating device contexts in a round-robin fashion.""" - - def __init__(self, devices): - """Creates a _DeviceContextGenerator object. +# The linter is confused. +# pylint: disable=abstract-class-instantiated +def make_fisher_estimator(placement_strategy=None, **kwargs): + """Creates Fisher estimator instances based on the placement strategy. - Example usage: + For example if the `placement_strategy` is 'round_robin' then + `FisherEstimatorRoundRobin` instance is returned. - ```python - dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) - with dcg(): - # All operations in this context will be placed on GPU 0 - ... - with dcg(): - # All operations in this context will be placed on GPU 1 - ... - ``` + Args: + placement_strategy: `string`, Strategy to be used for placing covariance + variables, covariance ops and inverse ops. Check + `placement.FisherEstimatorRoundRobin` for a concrete example. + **kwargs: Arguments to be passed into `FisherEstimator` class initializer. - Args: - devices: An iterable of device strings (or None). Successive calls to - __call__ will give contexts which place devices on these devices in - a round-robin fashion. - """ - self._cycle = None if devices is None else itertools.cycle(devices) + Returns: + An instance of class which inherits from `FisherEstimator` and the mixin + which implements specific placement strategy. See, + `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and + `RoundRobinPlacementMixin`. - @contextlib.contextmanager - def __call__(self): - """Returns a context manager specifying the default device.""" - if self._cycle is None: - yield - else: - with tf_ops.device(next(self._cycle)): - yield + Raises: + ValueError: If the `placement_strategy` is not equal to 'round_robin'. + """ + if placement_strategy in [None, "round_robin"]: + return FisherEstimatorRoundRobin(**kwargs) + else: + raise ValueError("Unimplemented vars and ops placement strategy : %s", + placement_strategy) +# pylint: enable=abstract-class-instantiated +@six.add_metaclass(abc.ABCMeta) class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher. - Attributes: - cov_update_thunks: list of no-arg functions. Executing a function adds - covariance update ops for a single FisherFactor to the graph. - cov_update_ops: List of Ops. Running an op updates covariance matrices for a - single FisherFactor. - cov_update_op: Op. Running updates covariance matrices for all - FisherFactors. - inv_update_thunks: list of no-arg functions. Executing a function adds - inverse update ops for a single FisherFactor to the graph. - inv_update_ops: List of Ops. Running an op updates inverse matrices for a - single FisherFactor. - inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + This is an abstract base class which does not implement a strategy for + placing covariance variables, covariance update ops and inverse update ops. + The placement strategies are implemented in `placement.py`. See + `FisherEstimatorRoundRobin` for example of a concrete subclass with + a round-robin placement strategy. """ def __init__(self, @@ -87,26 +78,31 @@ class FisherEstimator(object): cov_ema_decay, damping, layer_collection, + exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, - cov_devices=None, - inv_devices=None): + name="FisherEstimator"): """Create a FisherEstimator object. Args: - variables: A list of the variables for which to estimate the Fisher. This - must match the variables registered in layer_collection (if it is not - None). + variables: A `list` of variables or `callable` which returns the variables + for which to estimate the Fisher. This must match the variables + registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. - damping: The damping factor used to stabilize training due to errors in - the local approximation with the Fisher information matrix, and to - regularize the update direction by making it closer to the gradient. - (Higher damping means the update looks more like a standard gradient - update - see Tikhonov regularization.) + damping: float. The damping factor used to stabilize training due to + errors in the local approximation with the Fisher information matrix, + and to regularize the update direction by making it closer to the + gradient. (Higher damping means the update looks more like a standard + gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher blocks, kronecker factors, and losses associated with the graph. + exps: List of floats or ints. These represent the different matrix + powers of the approximate Fisher that the FisherEstimator will be able + to multiply vectors by. If the user asks for a matrix power other + one of these (or 1, which is always supported), there will be a + failure. (Default: (-1,)) estimation_mode: The type of estimator to use for the Fishers. Can be 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach @@ -125,24 +121,17 @@ class FisherEstimator(object): equal to the output dimension, roughly speaking. colocate_gradients_with_ops: Whether we should request gradients be colocated with their respective ops. (Default: True) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - + name: A string. A name given to this estimator, which is added to the + variable scope when constructing variables and ops. + (Default: "FisherEstimator") Raises: ValueError: If no losses have been registered with layer_collection. """ - - self._cov_ema_decay = cov_ema_decay self._variables = variables + self._cov_ema_decay = cov_ema_decay self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection - self._layers.create_subgraph() - self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, @@ -151,39 +140,107 @@ class FisherEstimator(object): } self._colocate_gradients_with_ops = colocate_gradients_with_ops - # TODO(b/70674513): Factor device placement outside of this class. - self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) - if inv_devices == cov_devices: - self._inv_device_context_generator = self._cov_device_context_generator - else: - self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) - - self._instantiate_factors() + self._made_vars = False + self._exps = exps - self.cov_update_thunks = [ - self._create_cov_update_thunk(factor) - for factor in self._layers.get_factors() - ] - self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks] - self.cov_update_op = control_flow_ops.group( - self.cov_update_ops, name="cov_update_op") - - self.inv_update_thunks = [ - self._create_inv_update_thunk(factor) - for factor in self._layers.get_factors() - ] - self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks] - self.inv_update_op = control_flow_ops.group( - self.inv_update_ops, name="inv_update_op") + self._name = name @property def variables(self): - return self._variables + if callable(self._variables): + return self._variables() + else: + return self._variables @property def damping(self): return self._damping + @property + def blocks(self): + """All registered FisherBlocks.""" + return self._layers.get_blocks() + + @property + def factors(self): + """All registered FisherFactors.""" + return self._layers.get_factors() + + @property + def name(self): + return self._name + + @abc.abstractmethod + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. For example in case of + round robin placement a new device is chosen for each factor by cycling + through list of devices in the cov_devices argument. If cov_devices is None + then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + pass + + @abc.abstractmethod + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks with a specific placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the cov_devices + argument. If cov_devices is None then no explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the inv_devices argument. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + pass + def _apply_transformation(self, vecs_and_vars, transform): """Applies an block-wise transformation to the corresponding vectors. @@ -217,9 +274,7 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ - - return self._apply_transformation(vecs_and_vars, - lambda fb, vec: fb.multiply_inverse(vec)) + return self.multiply_matpower(-1, vecs_and_vars) def multiply(self, vecs_and_vars): """Multiplies the vectors by the corresponding (damped) blocks. @@ -231,9 +286,22 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ + return self.multiply_matpower(1, vecs_and_vars) - return self._apply_transformation(vecs_and_vars, - lambda fb, vec: fb.multiply(vec)) + def multiply_matpower(self, exp, vecs_and_vars): + """Multiplies the vecs by the corresponding matrix powers of the blocks. + + Args: + exp: A float representing the power to raise the blocks by before + multiplying it by the vector. + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) + return self._apply_transformation(vecs_and_vars, fcn) def _instantiate_factors(self): """Instantiates FisherFactors' variables. @@ -241,9 +309,9 @@ class FisherEstimator(object): Raises: ValueError: If estimation_mode was improperly specified at construction. """ - fisher_blocks_list = self._layers.get_blocks() + blocks = self.blocks tensors_to_compute_grads = [ - fb.tensors_to_compute_grads() for fb in fisher_blocks_list + block.tensors_to_compute_grads() for block in blocks ] try: @@ -253,45 +321,131 @@ class FisherEstimator(object): raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) - # TODO(b/68033310): This loop round-robins the "concat" operations which - # gather the inputs for the cov_updates. In future, we might do these - # computations locally then communicate the results, which would require a - # modification to this code. - for grads_list, fb in zip(grads_lists, fisher_blocks_list): - with self._cov_device_context_generator(): - fb.instantiate_factors(grads_list, self.damping) + for grads_list, block in zip(grads_lists, blocks): + block.instantiate_factors(grads_list, self.damping) + + def _check_vars_unmade_and_set_made_flag(self): + if self._made_vars: + raise Exception("Already made variables.") + self._made_vars = True + + def made_vars(self): + return self._made_vars + + def _register_matrix_functions(self): + for exp in self._exps: + for block in self.blocks: + block.register_matpower(exp) - def _create_cov_update_thunk(self, factor): + def _finalize_layer_collection(self): + self._layers.create_subgraph() + self._layers.check_registration(self.variables) + self._instantiate_factors() + self._register_matrix_functions() + + def create_ops_and_vars_thunks(self, scope=None): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All thunks will execute inside + of a variable scope of the given name. (Default: None) + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + self._check_vars_unmade_and_set_made_flag() + + self._finalize_layer_collection() + + scope = self.name if scope is None else scope + + cov_variable_thunks = [ + self._create_cov_variable_thunk(factor, scope) + for factor in self.factors + ] + cov_update_thunks = [ + self._create_cov_update_thunk(factor, scope) for factor in self.factors + ] + inv_variable_thunks = [ + self._create_inv_variable_thunk(factor, scope) + for factor in self.factors + ] + inv_update_thunks = [ + self._create_inv_update_thunk(factor, scope) for factor in self.factors + ] + + return (cov_variable_thunks, cov_update_thunks, + inv_variable_thunks, inv_update_thunks) + + def _create_cov_variable_thunk(self, factor, scope): + """Constructs a covariance variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_cov_variables() + + return thunk + + def _create_cov_update_thunk(self, factor, scope): """Constructs a covariance update thunk for a single FisherFactor.""" def thunk(): - with tf_ops.name_scope( - "create_cov_update_thunk", values=[self._cov_ema_decay]): + with variable_scope.variable_scope(scope): return factor.make_covariance_update_op(self._cov_ema_decay) return thunk - def _create_inv_update_thunk(self, factor): + def _create_inv_variable_thunk(self, factor, scope): + """Constructs a inverse variable thunk for a single FisherFactor.""" + + def thunk(): + with variable_scope.variable_scope(scope): + return factor.instantiate_inv_variables() + + return thunk + + def _create_inv_update_thunk(self, factor, scope): """Constructs an inverse update thunk for a single FisherFactor.""" def thunk(): - with tf_ops.name_scope("create_inv_update_thunk"): - with self._inv_device_context_generator(): - return control_flow_ops.group(factor.make_inverse_update_ops()) + with variable_scope.variable_scope(scope): + return control_flow_ops.group(factor.make_inverse_update_ops()) return thunk def _get_grads_lists_gradients(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device grads_flat = gradients_impl.gradients( - self._layers.total_sampled_loss(), + self._layers.eval_losses_on_samples(), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): + # Passing in a list of loss values is better than passing in the sum as + # the latter creates unnessesary ops on the default device grads_flat = gradients_impl.gradients( - self._layers.total_loss(), + self._layers.eval_losses(), nest.flatten(tensors), colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) @@ -300,9 +454,10 @@ class FisherEstimator(object): def _get_transformed_random_signs(self): transformed_random_signs = [] for loss in self._layers.losses: - transformed_random_signs.append( - loss.multiply_fisher_factor( - utils.generate_random_signs(loss.fisher_factor_inner_shape))) + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) return transformed_random_signs def _get_grads_lists_curvature_prop(self, tensors): @@ -321,13 +476,20 @@ class FisherEstimator(object): # Loop over all coordinates of all losses. grads_all = [] for loss in self._layers.losses: - for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): - transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( - index) - grads_flat = gradients_impl.gradients( - loss.inputs, - nest.flatten(tensors), - grad_ys=transformed_one_hot, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) + + +class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, + FisherEstimator): + """Fisher estimator which provides round robin device placement strategy.""" + pass diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 0d2fa706f5853570bb8c04a9b9ac3378e2f2386e..00b3673a742e92057b0a1673d3f42a19379111fe 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -19,11 +19,11 @@ Information matrix. Suppose one has a model that parameterizes a posterior distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its Fisher Information matrix is given by, - F(params) = E[ v(x, y, params) v(x, y, params)^T ] + $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ where, - v(x, y, params) = (d / d params) log p(y | x, params) + $$v(x, y, params) = (d / d params) log p(y | x, params)$$ and the expectation is taken with respect to the data's distribution for 'x' and the model's posterior distribution for 'y', @@ -40,12 +40,15 @@ from __future__ import print_function import abc import enum # pylint: disable=g-bad-import-order +import numpy as np import six from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest # For blocks corresponding to convolutional layers, or any type of block where # the parameters can be thought of as being replicated in time or space, @@ -82,7 +85,7 @@ def normalize_damping(damping, num_replications): def compute_pi_tracenorm(left_cov, right_cov): """Computes the scalar constant pi for Tikhonov regularization/damping. - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: @@ -92,10 +95,22 @@ def compute_pi_tracenorm(left_cov, right_cov): Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ + + def _trace(cov): + if len(cov.shape) == 1: + # Diagonal matrix. + return math_ops.reduce_sum(cov) + elif len(cov.shape) == 2: + # Full matrix. + return math_ops.trace(cov) + else: + raise ValueError( + "What's the trace of a Tensor of rank %d?" % len(cov.shape)) + # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] return math_ops.sqrt(left_norm / right_norm) @@ -109,12 +124,44 @@ def compute_pi_adjusted_damping(left_cov, right_cov, damping): return (damping, damping) +class PackagedFunc(object): + """A Python thunk with a stable ID. + + Enables stable names for lambdas. + """ + + def __init__(self, func, func_id): + """Initializes PackagedFunc. + + Args: + func: a zero-arg Python function. + func_id: a hashable, function that produces a hashable, or a list/tuple + thereof. + """ + self._func = func + func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) + self._func_id = func_id + + def __call__(self): + return self._func() + + @property + def func_id(self): + """A hashable identifier for this function.""" + return tuple(elt() if callable(elt) else elt for elt in self._func_id) + + +def _package_func(func, func_id): + return PackagedFunc(func, func_id) + + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): """Abstract base class for objects modeling approximate Fisher matrix blocks. - Subclasses must implement multiply_inverse(), instantiate_factors(), and - tensors_to_compute_grads() methods. + Subclasses must implement register_matpower, multiply_matpower, + instantiate_factors, tensors_to_compute_grads, and num_registered_towers + methods. """ def __init__(self, layer_collection): @@ -133,6 +180,32 @@ class FisherBlock(object): pass @abc.abstractmethod + def register_matpower(self, exp): + """Registers a matrix power to be computed by the block. + + Args: + exp: A float representing the power to raise the block by. + """ + pass + + def register_inverse(self): + """Registers a matrix inverse to be computed by the block.""" + self.register_matpower(-1) + + @abc.abstractmethod + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + exp: A float representing the power to raise the block by before + multiplying it by the vector. + + Returns: + The vector left-multiplied by the (damped) matrix-power of the block. + """ + pass + def multiply_inverse(self, vector): """Multiplies the vector by the (damped) inverse of the block. @@ -142,9 +215,8 @@ class FisherBlock(object): Returns: The vector left-multiplied by the (damped) inverse of the block. """ - pass + return self.multiply_matpower(vector, -1) - @abc.abstractmethod def multiply(self, vector): """Multiplies the vector by the (damped) block. @@ -154,7 +226,7 @@ class FisherBlock(object): Returns: The vector left-multiplied by the (damped) block. """ - pass + return self.multiply_matpower(vector, 1) @abc.abstractmethod def tensors_to_compute_grads(self): @@ -163,8 +235,8 @@ class FisherBlock(object): pass @abc.abstractproperty - def num_registered_minibatches(self): - """Number of minibatches registered for this FisherBlock. + def num_registered_towers(self): + """Number of towers registered for this FisherBlock. Typically equal to the number of towers in a multi-tower setup. """ @@ -195,21 +267,18 @@ class FullFB(FisherBlock): super(FullFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._damping = damping + self._damping_func = _package_func(lambda: damping, (damping,)) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullFactor, (grads_list, self._batch_size)) - self._factor.register_damped_inverse(damping) - def multiply_inverse(self, vector): - inverse = self._factor.get_damped_inverse(self._damping) - out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) - return utils.column_to_tensors(vector, out_flat) + def register_matpower(self, exp): + self._factor.register_matpower(exp, self._damping_func) - def multiply(self, vector): + def multiply_matpower(self, vector, exp): vector_flat = utils.tensors_to_column(vector) - out_flat = ( - math_ops.matmul(self._factor.get_cov(), vector_flat) + - self._damping * vector_flat) + out_flat = self._factor.left_multiply_matpower( + vector_flat, exp, self._damping_func) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): @@ -219,8 +288,8 @@ class FullFB(FisherBlock): def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -228,7 +297,7 @@ class FullFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -259,28 +328,30 @@ class NaiveDiagonalFB(FisherBlock): super(NaiveDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - self._damping = damping + self._damping_func = _package_func(lambda: damping, (damping,)) + self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - def multiply_inverse(self, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = vector_flat / (self._factor.get_cov() + self._damping) - return utils.column_to_tensors(vector, out_flat) + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass - def multiply(self, vector): + def multiply_matpower(self, vector, exp): vector_flat = utils.tensors_to_column(vector) - out_flat = vector_flat * (self._factor.get_cov() + self._damping) + out_flat = self._factor.left_multiply_matpower( + vector_flat, exp, self._damping_func) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): - return array_ops.diag(array_ops.reshape(self._factor.get_cov(), (-1,))) + return self._factor.get_cov() def tensors_to_compute_grads(self): return self._params - def register_additional_minibatch(self, batch_size): - """Register an additional minibatch. + def register_additional_tower(self, batch_size): + """Register an additional tower. Args: batch_size: The batch size, used in the covariance estimator. @@ -288,7 +359,7 @@ class NaiveDiagonalFB(FisherBlock): self._batch_sizes.append(batch_size) @property - def num_registered_minibatches(self): + def num_registered_towers(self): return len(self._batch_sizes) @property @@ -296,7 +367,92 @@ class NaiveDiagonalFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class FullyConnectedDiagonalFB(FisherBlock): +class InputOutputMultiTower(object): + """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" + + def __init__(self, *args, **kwargs): + self.__inputs = [] + self.__outputs = [] + super(InputOutputMultiTower, self).__init__(*args, **kwargs) + + def _process_data(self, grads_list): + """Process data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + The initial format of self._inputs is expected to be a list of Tensors + over towers. Similarly grads_lists is expected to be a list over sources + of such lists. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single + tensor (represented as a PartitionedTensor object) equal to the + concatenation (across towers) of all of the elements of self._inputs. And + similarly grads_list is formatted into a tuple (over sources) of such + tensors (also represented as PartitionedTensors). + + If TOWER_STRATEGY is "separate", formatting of inputs and grads_list + remains unchanged from the initial format (although possibly converting + from lists into tuples). + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". + """ + inputs = self._inputs + # inputs is a list over towers of Tensors + # grads_list is a list of list with the first index being sources and the + # second being towers. + if fisher_factors.TOWER_STRATEGY == "concat": + # Merge towers together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + # Do the same for grads_list but preserve leading sources dimension + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + elif fisher_factors.TOWER_STRATEGY == "separate": + inputs = tuple(inputs) + grads_list = tuple(grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + return inputs, grads_list + + def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" + return tuple(self._outputs) + + def register_additional_tower(self, inputs, outputs): + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_towers(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + + @property + def _inputs(self): + return self.__inputs + + @property + def _outputs(self): + return self.__outputs + + +class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -306,14 +462,14 @@ class FullyConnectedDiagonalFB(FisherBlock): Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, - Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ] + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ Consider fully connected layer in this model with (unshared) weight matrix 'w'. For an example 'x' that produces layer inputs 'a' and output preactivations 's', - v(x, y, w) = vec( a (d loss / d s)^T ) + $$v(x, y, w) = vec( a (d loss / d s)^T )$$ This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding to the layer's parameters 'w'. @@ -328,78 +484,46 @@ class FullyConnectedDiagonalFB(FisherBlock): has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = [] - self._outputs = [] self._has_bias = has_bias super(FullyConnectedDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) - self._damping = damping self._factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedDiagonalFactor, (inputs, grads_list, self._has_bias)) - def multiply_inverse(self, vector): - """Approximate damped inverse Fisher-vector product. - - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. + self._damping_func = _package_func(lambda: damping, (damping,)) - Returns: - Tensor of the same shape, corresponding to the inverse Fisher-vector - product. - """ - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass - def multiply(self, vector): - """Approximate damped Fisher-vector product. + def multiply_matpower(self, vector, exp): + """Multiplies the vector by the (damped) matrix-power of the block. Args: vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape [input_size, output_size] corresponding to layer's weights. If not, a 2-tuple of the former and a Tensor of shape [output_size] corresponding to the layer's bias. + exp: A scalar representing the power to raise the block before multiplying + it by the vector. Returns: - Tensor of the same shape, corresponding to the Fisher-vector product. + The vector left-multiplied by the (damped) matrix-power of the block. """ - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_vec = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply_matpower( + reshaped_vec, exp, self._damping_func) return utils.mat2d_to_layer_params(vector, reshaped_out) - def tensors_to_compute_grads(self): - """Tensors to compute derivative of loss with respect to.""" - return self._outputs - - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - result = len(self._inputs) - assert result == len(self._outputs) - return result - -class ConvDiagonalFB(FisherBlock): - """FisherBlock for convolutional layers using a diagonal approx. +class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): + """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" @@ -408,14 +532,14 @@ class ConvDiagonalFB(FisherBlock): Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is, - Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ] + $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ]$$ Consider a convoluational layer in this model with (unshared) filter matrix 'w'. For an example image 'x' that produces layer inputs 'a' and output preactivations 's', - v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) + $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ where 'loc' is a single (x, y) location in an image. @@ -423,7 +547,13 @@ class ConvDiagonalFB(FisherBlock): to the layer's parameters 'w'. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + strides, + padding, + data_format=None, + dilations=None): """Creates a ConvDiagonalFB block. Args: @@ -435,90 +565,115 @@ class ConvDiagonalFB(FisherBlock): containing the previous and a Tensor of shape [out_channels]. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (e.g. "SAME"). + data_format: str or None. Format of input data. + dilations: List of 4 ints or None. Rate for dilation along all dimensions. + + Raises: + ValueError: if strides is not length-4. + ValueError: if dilations is not length-4. + ValueError: if channel is not last dimension. """ - self._inputs = [] - self._outputs = [] - self._strides = tuple(strides) if isinstance(strides, list) else strides + if len(strides) != 4: + raise ValueError("strides must contain 4 numbers.") + + if dilations is None: + dilations = [1, 1, 1, 1] + + if len(dilations) != 4: + raise ValueError("dilations must contain 4 numbers.") + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + self._strides = maybe_tuple(strides) self._padding = padding + self._data_format = data_format + self._dilations = maybe_tuple(dilations) self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params self._filter_shape = tuple(fltr.shape.as_list()) + if len(self._filter_shape) != 4: + raise ValueError( + "Convolution filter must be of shape" + " [filter_height, filter_width, in_channels, out_channels].") + super(ConvDiagonalFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # Concatenate inputs, grads_list into single Tensors. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - inputs_shape = tuple(inputs.shape.as_list()) - self._num_locations = ( - inputs_shape[1] * inputs_shape[2] // - (self._strides[1] * self._strides[2])) - - self._damping = (self._num_locations - * normalize_damping(damping, self._num_locations)) + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, (inputs, grads_list, self._filter_shape, self._strides, self._padding, - self._has_bias)) - - def multiply_inverse(self, vector): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply(self, vector): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) - return utils.mat2d_to_layer_params(vector, reshaped_out) + self._data_format, self._dilations, self._has_bias)) - def tensors_to_compute_grads(self): - return self._outputs + def damping_func(): + return self._num_locations * normalize_damping(damping, + self._num_locations) - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. + damping_id = (self._num_locations, "mult", "normalize_damping", damping, + self._num_locations) + self._damping_func = _package_func(damping_func, damping_id) - Args: - inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to - the convolution. - outputs: Tensor of shape [batch_size, height, width, output_size]. Layer - preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass - @property - def num_registered_minibatches(self): - return len(self._inputs) + def multiply_matpower(self, vector, exp): + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply_matpower( + reshaped_vect, exp, self._damping_func) + return utils.mat2d_to_layer_params(vector, reshaped_out) class KroneckerProductFB(FisherBlock): - """A base class for FisherBlocks with separate input and output factors. + """A base class for blocks with separate input and output Kronecker factors. The Fisher block is approximated as a Kronecker product of the input and output factors. """ - def _register_damped_input_and_output_inverses(self, damping): - """Registers damped inverses for both the input and output factors. - - Sets the instance members _input_damping and _output_damping. Requires the - instance members _input_factor and _output_factor. + def __init__(self, layer_collection): + super(KroneckerProductFB, self).__init__(layer_collection) + + def _setup_damping(self, damping, normalization=None): + """Makes functions that compute the damping values for both factors.""" + def compute_damping(): + if normalization is not None: + maybe_normalized_damping = normalize_damping(damping, normalization) + else: + maybe_normalized_damping = damping + + return compute_pi_adjusted_damping(self._input_factor.get_cov(), + self._output_factor.get_cov(), + maybe_normalized_damping**0.5) + + if normalization is not None: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + "normalize_damping", damping, normalization, "power", 0.5) + else: + damping_id = ("compute_pi_adjusted_damping", + "cov", self._input_factor.name, + "cov", self._output_factor.name, + damping, "power", 0.5) - Args: - damping: The base damping factor (float or Tensor) for the damped inverse. - """ - self._input_damping, self._output_damping = compute_pi_adjusted_damping( - self._input_factor.get_cov(), - self._output_factor.get_cov(), - damping**0.5) + self._input_damping_func = _package_func(lambda: compute_damping()[0], + damping_id + ("ref", 0)) + self._output_damping_func = _package_func(lambda: compute_damping()[1], + damping_id + ("ref", 1)) - self._input_factor.register_damped_inverse(self._input_damping) - self._output_factor.register_damped_inverse(self._output_damping) + def register_matpower(self, exp): + self._input_factor.register_matpower(exp, self._input_damping_func) + self._output_factor.register_matpower(exp, self._output_damping_func) @property def _renorm_coeff(self): @@ -532,32 +687,15 @@ class KroneckerProductFB(FisherBlock): """ return 1.0 - def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_damped_inverse( - self._output_damping) - reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = math_ops.matmul(left_factor_inv, - math_ops.matmul(reshaped_vector, - right_factor_inv)) - if self._renorm_coeff != 1.0: - reshaped_out /= math_ops.cast( - self._renorm_coeff, dtype=reshaped_out.dtype) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply(self, vector): - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() + def multiply_matpower(self, vector, exp): reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = ( - math_ops.matmul(reshaped_vector, right_factor) + - self._output_damping * reshaped_vector) - reshaped_out = ( - math_ops.matmul(left_factor, reshaped_out) + - self._input_damping * reshaped_out) + reshaped_out = self._output_factor.right_multiply_matpower( + reshaped_vector, exp, self._output_damping_func) + reshaped_out = self._input_factor.left_multiply_matpower( + reshaped_out, exp, self._input_damping_func) if self._renorm_coeff != 1.0: - reshaped_out *= math_ops.cast( - self._renorm_coeff, dtype=reshaped_out.dtype) + renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype) + reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) def full_fisher_block(self): @@ -574,7 +712,49 @@ class KroneckerProductFB(FisherBlock): right_factor) -class FullyConnectedKFACBasicFB(KroneckerProductFB): +class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): + """K-FAC FisherBlock for embedding layers. + + This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its + input factor is approximated by a diagonal matrix. In the case that each + example references exactly one embedding, this approximation is exact. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size): + """Creates a EmbeddingKFACFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) + self._setup_damping(damping) + + +class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. This uses the Kronecker-factorized approximation from the original @@ -590,8 +770,6 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): has_bias: Whether the component Kronecker factors have an additive bias. (Default: False) """ - self._inputs = [] - self._outputs = [] self._has_bias = has_bias super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) @@ -606,42 +784,19 @@ class FullyConnectedKFACBasicFB(KroneckerProductFB): damping: 0-D Tensor or float. 'damping' * identity is approximately added to this FisherBlock's Fisher approximation. """ - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, ((inputs,), self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( # - fisher_factors.FullyConnectedKroneckerFactor, # + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) - self._register_damped_input_and_output_inverses(damping) + self._setup_damping(damping) - def tensors_to_compute_grads(self): - return self._outputs - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. - - Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to the - matrix-multiply. - outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. - """ - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_minibatches(self): - return len(self._inputs) - - -class ConvKFCBasicFB(KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. +class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): + """FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -650,12 +805,12 @@ class ConvKFCBasicFB(KroneckerProductFB): 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', this FisherBlock estimates, - F(w) = #locations * kronecker(E[flat(a) flat(a)^T], - E[flat(ds) flat(ds)^T]) + $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T], + E[flat(ds) flat(ds)^T])$$ where - ds = (d / ds) log p(y | x, w) + $$ds = (d / ds) log p(y | x, w)$$ #locations = number of (x, y) locations where 'w' is applied. where the expectation is taken over all examples and locations and flat() @@ -664,23 +819,40 @@ class ConvKFCBasicFB(KroneckerProductFB): See equation 23 in https://arxiv.org/abs/1602.01407 for details. """ - def __init__(self, layer_collection, params, strides, padding): + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None): """Creates a ConvKFCBasicFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [kernel_height, kernel_width, + kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements containing the previous and a Tensor of shape [out_channels]. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (1-D of Tensor length 4). + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". """ - self._inputs = [] - self._outputs = [] - self._strides = tuple(strides) if isinstance(strides, list) else strides self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = isinstance(params, (tuple, list)) fltr = params[0] if self._has_bias else params @@ -689,145 +861,606 @@ class ConvKFCBasicFB(KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): - # TODO(b/68033310): Validate which of, - # (1) summing on a single device (as below), or - # (2) on each device in isolation and aggregating - # is faster. - inputs = _concat_along_batch_dim(self._inputs) - grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._strides, self._padding, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - damping = normalize_damping(damping, self._num_locations) - self._register_damped_input_and_output_inverses(damping) - self._damping = damping + self._setup_damping(damping, normalization=self._num_locations) @property def _renorm_coeff(self): return self._num_locations - def tensors_to_compute_grads(self): - return self._outputs - def register_additional_minibatch(self, inputs, outputs): - """Registers an additional minibatch to the FisherBlock. +class DepthwiseConvDiagonalFB(ConvDiagonalFB): + """FisherBlock for depthwise_conv2d(). + + Equivalent to ConvDiagonalFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. Args: - inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to - the convolution. - outputs: Tensor of shape [batch_size, height, width, output_size]. Layer - preactivations. + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. """ - self._inputs.append(inputs) - self._outputs.append(outputs) + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvDiagonalFB, self).__init__( + layer_collection=layer_collection, + params=params, + strides=strides, + padding=padding, + dilations=rate, + data_format=data_format) + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def multiply_matpower(self, vector, exp): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower( + conv2d_vector, exp) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - @property - def num_registered_minibatches(self): - return len(self._inputs) +class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): + """FisherBlock for depthwise_conv2d(). -def _concat_along_batch_dim(tensor_list): - """Concatenate tensors along batch (first) dimension. + Equivalent to ConvKFCBasicFB applied to each input channel in isolation. + """ + + def __init__(self, + layer_collection, + params, + strides, + padding, + rate=None, + data_format=None): + """Creates a DepthwiseConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: Tensor of shape [filter_height, filter_width, in_channels, + channel_multiplier]. + strides: List of 4 ints. Strides along all dimensions. + padding: str. Padding method. + rate: List of 4 ints or None. Rate for dilation along all dimensions. + data_format: str or None. Format of input data. + + Raises: + NotImplementedError: If parameters contains bias. + ValueError: If filter is not 4-D. + ValueError: If strides is not length-4. + ValueError: If rates is not length-2. + ValueError: If channels are not last dimension. + """ + if isinstance(params, (tuple, list)): + raise NotImplementedError("Bias not yet supported.") + + if params.shape.ndims != 4: + raise ValueError("Filter must be 4-D.") + + if len(strides) != 4: + raise ValueError("strides must account for 4 dimensions.") + + if rate is not None: + if len(rate) != 2: + raise ValueError("rate must only account for spatial dimensions.") + rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. + + if not utils.is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels-last.") + + super(DepthwiseConvKFCBasicFB, self).__init__( + layer_collection=layer_collection, + params=params, + padding=padding, + strides=strides, + dilation_rate=rate, + data_format=data_format, + extract_patches_fn="extract_image_patches") + + # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). + filter_height, filter_width, in_channels, channel_multiplier = ( + params.shape.as_list()) + self._filter_shape = (filter_height, filter_width, in_channels, + in_channels * channel_multiplier) + + def multiply_matpower(self, vector, exp): + conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) + conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower( + conv2d_vector, exp) + return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) + + +def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with conv2d. + + Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's + compatible with tf.nn.conv2d(). Args: - tensor_list: list of Tensors or list of tuples of Tensors. + filter: Tensor of shape [height, width, in_channels, channel_multiplier]. + name: None or str. Name of Op. Returns: - Tensor or tuple of Tensors. + Tensor of shape [height, width, in_channels, out_channels]. - Raises: - ValueError: If 'tensor_list' is empty. + """ + with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, channel_multiplier = ( + filter.shape.as_list()) + + results = [] + for i in range(in_channels): + # Slice out one in_channel's filter. Insert zeros around it to force it + # to affect that channel and that channel alone. + elements = [] + if i > 0: + elements.append( + array_ops.zeros( + [filter_height, filter_width, i, channel_multiplier])) + elements.append(filter[:, :, i:(i + 1), :]) + if i + 1 < in_channels: + elements.append( + array_ops.zeros([ + filter_height, filter_width, in_channels - (i + 1), + channel_multiplier + ])) + + # Concat along in_channel. + results.append( + array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) + + # Concat along out_channel. + return array_ops.concat(results, axis=-1, name="out_channel") + + +def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin + """Converts a convolution filter for use with depthwise_conv2d. + + Transforms a filter for use with tf.nn.conv2d() to one that's + compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along + the diagonal. + Args: + filter: Tensor of shape [height, width, in_channels, out_channels]. + name: None or str. Name of Op. + + Returns: + Tensor of shape, + [height, width, in_channels, channel_multiplier] + + Raises: + ValueError: if out_channels is not evenly divisible by in_channels. """ - if not tensor_list: - raise ValueError( - "Cannot concatenate Tensors if there are no Tensors to concatenate.") - - if isinstance(tensor_list[0], (tuple, list)): - # [(tensor1a, tensor1b), - # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b) - return tuple( - array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list)) - else: - # [tensor1, tensor2] --> tensor - return array_ops.concat(tensor_list, axis=0) + with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", + [filter]): + filter = ops.convert_to_tensor(filter) + filter_height, filter_width, in_channels, out_channels = ( + filter.shape.as_list()) + + if out_channels % in_channels != 0: + raise ValueError("out_channels must be evenly divisible by in_channels.") + channel_multiplier = out_channels // in_channels + + results = [] + filter = array_ops.reshape(filter, [ + filter_height, filter_width, in_channels, in_channels, + channel_multiplier + ]) + for i in range(in_channels): + # Slice out output corresponding to the correct filter. + filter_slice = array_ops.reshape( + filter[:, :, i, i, :], + [filter_height, filter_width, 1, channel_multiplier]) + results.append(filter_slice) + + # Concat along out_channel. + return array_ops.concat(results, axis=-2, name="in_channels") + + +def maybe_tuple(obj): + if not isinstance(obj, list): + return obj + return tuple(obj) def num_conv_locations(input_shape, strides): """Returns the number of spatial locations a 2D Conv kernel is applied to. Args: - input_shape: list representing shape of inputs to the Conv layer. - strides: list representing strides for the Conv kernel. + input_shape: List of ints representing shape of inputs to + tf.nn.convolution(). + strides: List of ints representing strides along spatial dimensions as + passed in to tf.nn.convolution(). Returns: A scalar |T| denoting the number of spatial locations for the Conv layer. """ - return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + spatial_input_locations = np.prod(input_shape[1:-1]) + + if strides is None: + spatial_strides_divisor = 1 + else: + spatial_strides_divisor = np.prod(strides) + + return spatial_input_locations // spatial_strides_divisor + + +class InputOutputMultiTowerMultiUse(InputOutputMultiTower): + """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" + def __init__(self, num_uses=None, *args, **kwargs): + self._num_uses = num_uses + super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) -class FullyConnectedMultiIndepFB(KroneckerProductFB): + def _process_data(self, grads_list): + """Process temporal/multi-use data into the format used by the factors. + + This function takes inputs and grads_lists data and processes it into + one of the formats expected by the FisherFactor classes (depending on + the value of the global configuration variable TOWER_STRATEGY). + + It accepts the data in one of two initial formats. The first possible + format is where self._inputs is a list of list of Tensors. The first index + is tower, the second is use/time-step. grads_list, meanwhile, is a list + over sources of such lists of lists. + + The second possible data format is where self._inputs is a Tensor with + uses/times-steps folded into the batch dimension. i.e. it is a Tensor + of shape [num_uses * size_batch, ...] which represents a reshape of a + Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is + a list over sources of such Tensors. + + There are two possible formats which inputs and grads_list are transformed + into. + + If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing + a single tensor (represented as a PartitionedTensor object) with all of + the data from the towers, as well as the uses/time-steps, concatenated + together. In this tensor the leading dimension is the batch and + use/time-step dimensions folded together (with 'use' being the major of + these two, so that the tensors can be thought of as reshapes of ones of + shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a + tuple over sources of such tensors. + + If TOWER_STRATEGY is "separate" the inputs are formatted into lists of + tensors over towers. Each of these tensors has a similar format to + the tensor produced by the "concat" option, except that each contains + only the data from a single tower. grads_list is similarly formatted + into a tuple over sources of such tuples. + + Args: + grads_list: grads_list in its initial format (see above). + + Returns: + inputs: self._inputs transformed into the appropriate format (see + above). + grads_list: grads_list transformed into the appropriate format (see + above). + + Raises: + ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". + ValueError: If the given/initial format of self._inputs and grads_list + isn't recognized, or doesn't agree with self._num_uses. + """ + + inputs = self._inputs + + if isinstance(inputs[0], (list, tuple)): + num_uses = len(inputs[0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of inputs.") + else: + self._num_uses = num_uses + + # Check that all mini-batches/towers have the same number of uses + if not all(len(input_) == num_uses for input_ in inputs): + raise ValueError("Length of inputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + inputs = tuple(zip(*inputs)) + + # Flatten the two dimensions + inputs = nest.flatten(inputs) + + # Merge everything together into a PartitionedTensor. We package it in + # a singleton tuple since the factors will expect a list over towers + inputs = (utils.PartitionedTensor(inputs),) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + # Now we perform the analogous processing for grads_list + if isinstance(grads_list[0][0], (list, tuple)): + num_uses = len(grads_list[0][0]) + if self._num_uses is not None and self._num_uses != num_uses: + raise ValueError("num_uses argument doesn't match length of outputs, " + "or length of outputs is inconsistent with length of " + "inputs.") + else: + self._num_uses = num_uses + + if not all(len(grad) == num_uses for grads in grads_list + for grad in grads): + raise ValueError("Length of outputs argument is inconsistent across " + "towers.") + + if fisher_factors.TOWER_STRATEGY == "concat": + # Reverse the tower and use/time-step indices, so that use is now first, + # and towers is second + grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) + + # Flatten the two dimensions, leaving the leading dimension (source) + # intact + grads_list = tuple(nest.flatten(grads) for grads in grads_list) + + # Merge inner dimensions together into PartitionedTensors. We package + # them in a singleton tuple since the factors will expect a list over + # towers + grads_list = tuple((utils.PartitionedTensor(grads),) + for grads in grads_list) + + elif fisher_factors.TOWER_STRATEGY == "separate": + # Merge together the uses/time-step dimension into PartitionedTensors, + # but keep the leading dimension (towers) intact for the factors to + # process individually. + grads_list = tuple(tuple(utils.PartitionedTensor(grad) + for grad in grads) + for grads in grads_list) + + else: + raise ValueError("Global config variable TOWER_STRATEGY must be one of " + "'concat' or 'separate'.") + + if self._num_uses is None: + raise ValueError("You must supply a value for the num_uses argument if " + "the number of uses cannot be inferred from inputs or " + "outputs arguments (e.g. if they are both given in the " + "single Tensor format, instead of as lists of Tensors.") + + return inputs, grads_list + + +class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters. + + This class implements the "independence across time" approximation from the + following paper: + https://openreview.net/pdf?id=HyMTkQZAb """ - def __init__(self, layer_collection, inputs, outputs, has_bias=False): + def __init__(self, layer_collection, has_bias=False, num_uses=None): """Creates a FullyConnectedMultiIndepFB block. Args: layer_collection: LayerCollection instance. - inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, - inputs_size]. - outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, - outputs_size]. has_bias: bool. If True, estimates Fisher with respect to a bias parameter as well as the layer's parameters. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) """ - - assert len(inputs) == len(outputs) - # We need to make sure inputs and outputs are tuples and not lists so that - # they get hashed by layer_collection.make_or_get_factor properly. - self._inputs = tuple(inputs) - self._outputs = tuple(outputs) self._has_bias = has_bias - self._num_uses = len(inputs) - super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) - - @property - def num_registered_minibatches(self): - # TODO(b/69411207): Add support for registering additional minibatches. - return 1 + super(FullyConnectedMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.FullyConnectedMultiKF, - ((self._inputs,), self._has_bias)) + ((inputs,), self._num_uses, self._has_bias)) self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - damping = normalize_damping(damping, self._num_uses) - self._register_damped_input_and_output_inverses(damping) + self._setup_damping(damping, normalization=self._num_uses) @property def _renorm_coeff(self): - return self._num_uses + return float(self._num_uses) - def tensors_to_compute_grads(self): - return self._outputs - def num_inputs(self): - return len(self._inputs) +class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Similar to ConvKFCBasicFB except that this version supports multiple + uses/time-steps via a standard independence approximation. Similar to the + "independence across time" used in FullyConnectedMultiIndepFB but generalized + in the obvious way to conv layers. + """ + + def __init__(self, + layer_collection, + params, + padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, + num_uses=None): + """Creates a ConvKFCBasicMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [..spatial_filter_shape.., + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + padding: str. Padding method. + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with uses/time folded into the + batch dimension (instead of uses/time being a list dimension). + (Default: None) + """ + self._padding = padding + self._strides = maybe_tuple(strides) + self._dilation_rate = maybe_tuple(dilation_rate) + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._padding, self._strides, + self._dilation_rate, self._data_format, self._extract_patches_fn, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + self._setup_damping(damping, normalization= + (self._num_locations * self._num_uses)) + + @property + def _renorm_coeff(self): + return self._num_locations * self._num_uses + + +class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): + """K-FAC FisherBlock for embedding layers used multiple times in the graph. + + Similar to EmbeddingKFACFB except that this version supports multiple uses + of the parameter within a single model. These uses could correspond to time + steps in an RNN architecture, but they don't have to. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size, num_uses=None): + """Creates a EmbeddingKFACMultiIndepFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + num_uses: int or None. Number of uses of the layer in the model's graph. + Only required if the data is formatted with time folded into the batch + dimension (instead of time being a list dimension). (Default: None) + """ + self._vocab_size = vocab_size + + super(EmbeddingKFACMultiIndepFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of list of Tensors. grads_list[i][j][k] is the + gradient of the loss with respect to 'outputs' from source 'i', + tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape + [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + inputs, grads_list = self._process_data(grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.EmbeddingInputKroneckerFactor, + (inputs, self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._setup_damping(damping, normalization=self._num_uses) + + @property + def _renorm_coeff(self): + return float(self._num_uses) class SeriesFBApproximation(enum.IntEnum): @@ -836,34 +1469,35 @@ class SeriesFBApproximation(enum.IntEnum): option2 = 2 -class FullyConnectedSeriesFB(FisherBlock): +class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, + KroneckerProductFB): """FisherBlock for fully-connected layers that share parameters across time. - See the following preprint for details: + This class implements the "Option 1" and "Option 2" approximation from the + following paper: https://openreview.net/pdf?id=HyMTkQZAb See the end of the appendix of the paper for a pseudo-code of the - algorithm being implemented by multiply_inverse here. Note that we are + algorithm being implemented by multiply_matpower here. Note that we are using pre-computed versions of certain matrix-matrix products to speed things up. This is explicitly explained wherever it is done. """ def __init__(self, layer_collection, - inputs, - outputs, has_bias=False, + num_uses=None, option=SeriesFBApproximation.option2): """Constructs a new `FullyConnectedSeriesFB`. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. - inputs: List of tensors of shape [batch_size, input_size]. - Inputs to the layer. - outputs: List of tensors of shape [batch_size, input_size]. - Outputs of the layer (before activations). has_bias: Whether the layer includes a bias parameter. + num_uses: int or None. Number of time-steps over which the layer + is used. Only required if the data is formatted with time folded into + the batch dimension (instead of time being a list dimension). + (Default: None) option: A `SeriesFBApproximation` specifying the simplifying assumption to be used in this block. `option1` approximates the cross-covariance over time as a symmetric matrix, while `option2` makes @@ -871,48 +1505,58 @@ class FullyConnectedSeriesFB(FisherBlock): 3.5 of the paper for more details. """ - assert len(inputs) == len(outputs) - # We need to make sure inputs and outputs are tuples and not lists so that - # they get hashed by layer_collection.make_or_get_factor properly. - self._inputs = tuple(inputs) - self._outputs = tuple(outputs) self._has_bias = has_bias - self._num_timesteps = len(inputs) self._option = option - super(FullyConnectedSeriesFB, self).__init__(layer_collection) + super(FullyConnectedSeriesFB, self).__init__( + layer_collection=layer_collection, + num_uses=num_uses) + + @property + def _num_timesteps(self): + return self._num_uses @property - def num_registered_minibatches(self): - # TODO(b/69411207): Add support for registering additional minibatches. - return 1 + def _renorm_coeff(self): + # This should no longer be used since the multiply_X functions from the base + # class have been overridden + assert False def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + fisher_factors.FullyConnectedMultiKF, + ((inputs,), self._num_uses, self._has_bias)) + self._input_factor.register_cov_dt1() self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list,)) + fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) + self._output_factor.register_cov_dt1() + + self._setup_damping(damping, normalization=self._num_uses) - damping = normalize_damping(damping, self._num_timesteps) - self._damping_input, self._damping_output = compute_pi_adjusted_damping( - self._input_factor.get_cov(), - self._output_factor.get_cov(), - damping**0.5) + def register_matpower(self, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") if self._option == SeriesFBApproximation.option1: - self._input_factor.register_option1quants(self._damping_input) - self._output_factor.register_option1quants(self._damping_output) + self._input_factor.register_option1quants(self._input_damping_func) + self._output_factor.register_option1quants(self._output_damping_func) elif self._option == SeriesFBApproximation.option2: - self._input_factor.register_option2quants(self._damping_input) - self._output_factor.register_option2quants(self._damping_output) + self._input_factor.register_option2quants(self._input_damping_func) + self._output_factor.register_option2quants(self._output_damping_func) else: raise ValueError( "Unrecognized FullyConnectedSeriesFB approximation: {}".format( self._option)) - def multiply_inverse(self, vector): + def multiply_matpower(self, vector, exp): + if exp != -1: + raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" + "multiplications.") + # pylint: disable=invalid-name Z = utils.layer_params_to_mat2d(vector) @@ -923,9 +1567,11 @@ class FullyConnectedSeriesFB(FisherBlock): if self._option == SeriesFBApproximation.option1: - # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. - L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) - L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) + L_A, psi_A = self._input_factor.get_option1quants( + self._input_damping_func) + L_G, psi_G = self._output_factor.get_option1quants( + self._output_damping_func) def gamma(x): # We are assuming that each case has the same number of time-steps. @@ -935,60 +1581,61 @@ class FullyConnectedSeriesFB(FisherBlock): T = self._num_timesteps return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) - # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) # Even though Y is Z-independent we are recomputing it from the psi's # each since Y depends on both A and G quantities, and it is relatively # cheap to compute. Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) - # Z = L_G^T * Z * L_A + # \\(Z = L_G^T * Z * L_A\\) # This is equivalent to the following computation from the original # pseudo-code: - # Z = G0^(-1/2) * Z * A0^(-1/2) - # Z = U_G^T * Z * U_A + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = U_G^T * Z * U_A\\) Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) - # Z = Z .* Y + # \\(Z = Z .* Y\\) Z *= Y - # Z = L_G * Z * L_A^T + # \\(Z = L_G * Z * L_A^T\\) # This is equivalent to the following computation from the original # pseudo-code: - # Z = U_G * Z * U_A^T - # Z = G0^(-1/2) * Z * A0^(-1/2) + # \\(Z = U_G * Z * U_A^T\\) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) elif self._option == SeriesFBApproximation.option2: - # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), - # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. - P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), + # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) + P_A, K_A, mu_A = self._input_factor.get_option2quants( + self._input_damping_func) P_G, K_G, mu_G = self._output_factor.get_option2quants( - self._damping_output) + self._output_damping_func) # Our approach differs superficially from the pseudo-code in the paper # in order to reduce the total number of matrix-matrix multiplies. # In particular, the first three computations in the pseudo code are - # Z = G0^(-1/2) * Z * A0^(-1/2) - # Z = Z - hPsi_G^T * Z * hPsi_A - # Z = E_G^T * Z * E_A - # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that - # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) + # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) + # \\(Z = E_G^T * Z * E_A\\) + # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that + # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) # the entire computation can be written as - # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) - # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A - # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) - # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A - # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A - # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A - # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) + # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) + # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) + # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) + # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) + # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) # This final expression is computed by the following two lines: - # Z = Z - P_G * Z * P_A^T + # \\(Z = Z - P_G * Z * P_A^T\\) Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) - # Z = K_G^T * Z * K_A + # \\(Z = K_G^T * Z * K_A\\) Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) - # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) # Be careful with the outer product. We don't want to accidentally # make it an inner-product instead. tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A @@ -999,13 +1646,13 @@ class FullyConnectedSeriesFB(FisherBlock): # We now perform the transpose/reverse version of the operations # derived above, whose derivation from the original pseudo-code is # analgous. - # Z = K_G * Z * K_A^T + # \\(Z = K_G * Z * K_A^T\\) Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) - # Z = Z - P_G^T * Z * P_A + # \\(Z = Z - P_G^T * Z * P_A\\) Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) - # Z = normalize (1/E[T]) * Z + # \\(Z = normalize (1/E[T]) * Z\\) # Note that this normalization is done because we compute the statistics # by averaging, not summing, over time. (And the gradient is presumably # summed over time, not averaged, and thus their scales are different.) @@ -1017,12 +1664,3 @@ class FullyConnectedSeriesFB(FisherBlock): return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name - - def multiply(self, vector): - raise NotImplementedError - - def tensors_to_compute_grads(self): - return self._outputs - - def num_inputs(self): - return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index ac396309206fe09af65c2b70840a513fb25b579b..c04cf727fa958160d61c7a3638ec65f6c93c2f24 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -29,6 +29,7 @@ _allowed_symbols = [ 'NaiveDiagonalFB', 'FullyConnectedDiagonalFB', 'KroneckerProductFB', + 'EmbeddingKFACFB', 'FullyConnectedKFACBasicFB', 'ConvKFCBasicFB', 'ConvDiagonalFB', @@ -36,7 +37,9 @@ _allowed_symbols = [ 'compute_pi_tracenorm', 'compute_pi_adjusted_damping', 'num_conv_locations', - 'normalize_damping' + 'normalize_damping', + 'LEFT_MULTIPLY', + 'RIGHT_MULTIPLY', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index bcba18ae147c6ceca50bc9a2a17e01fc201d88c1..0d40d265a1727075d0ba721b0d9a756c38269a96 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -25,17 +25,19 @@ import numpy as np import six from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import moving_averages +from tensorflow.python.util import nest + # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). @@ -53,36 +55,25 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 -# Colocate the covariance ops and variables with the input tensors for each -# factor. -COLOCATE_COV_OPS_WITH_INPUTS = True - - -@contextlib.contextmanager -def maybe_colocate_with(op): - """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`.""" - if COLOCATE_COV_OPS_WITH_INPUTS: - if isinstance(op, (list, tuple)): - with tf_ops.colocate_with(op[0]): - yield - else: - with tf_ops.colocate_with(op): - yield - else: - yield +# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data +# passed to the factors from the blocks will be concatenated across towers +# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over +# towers will be passed in, and the factors will iterate over this and do the +# cov computations separately for each one, averaging the results together. +TOWER_STRATEGY = "concat" def set_global_constants(init_covariances_at_zero=None, zero_debias=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, - colocate_cov_ops_with_inputs=None): + tower_strategy=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD - global COLOCATE_COV_OPS_WITH_INPUTS + global TOWER_STRATEGY if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero @@ -92,8 +83,8 @@ def set_global_constants(init_covariances_at_zero=None, EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold - if colocate_cov_ops_with_inputs is not None: - COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs + if tower_strategy is not None: + TOWER_STRATEGY = tower_strategy def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument @@ -112,52 +103,13 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def extract_image_patches(image, ksizes, strides, padding, name=None): - """Extracts image patches for an N-dimensional convolution. - - This function is a compatibility wrapper over tf.extract_image_patches(), as - ExtractImagePatches isn't yet implemented in XLA. - - Args: - image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images. - All dimensions except 'batch' must be defined. - ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each - dimension. - strides: [stride_x, stride_y, ...]. Spatial stride for filter in each - dimension. - padding: str. "VALID" or "SAME". - name: str or None. name of Op. - - Returns: - result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels]. - Contains image patches to which conv kernel would be applied for each - output location. [out_x, out_y, ...] depends on padding. - """ - if not utils.on_tpu(): - return array_ops.extract_image_patches( - image, - ksizes=([1] + list(ksizes) + [1]), - strides=([1] + list(strides) + [1]), - rates=[1, 1, 1, 1], - padding=padding, - name=name) - - with tf_ops.name_scope(name, "extract_image_patches", - [image, ksizes, strides, padding]): - batch = image.shape.as_list()[0] - in_channels = image.shape.as_list()[-1] - - # Map each input feature to a location in the output. - out_channels = np.prod(ksizes) * in_channels - filters = linalg_ops.eye(out_channels), - filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels]) - - result = nn.convolution(image, filters, padding, strides=strides) - out_spatial = result.shape.as_list()[1:-1] - result = array_ops.reshape( - result, [batch or -1] + out_spatial + ksizes + [in_channels]) - - return result +@contextlib.contextmanager +def place_on_device(device): + if device is not None and len(device): + with tf_ops.device(device): + yield + else: + yield def compute_cov(tensor, tensor_right=None, normalizer=None): @@ -229,7 +181,9 @@ def scope_string_from_params(params): name_parts = [] for param in params: - if isinstance(param, (tuple, list)): + if param is None: + name_parts.append("None") + elif isinstance(param, (tuple, list)): if all([isinstance(p, int) for p in param]): name_parts.append("-".join([str(p) for p in param])) else: @@ -238,6 +192,8 @@ def scope_string_from_params(params): name_parts.append(str(param)) elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) + elif isinstance(param, utils.PartitionedTensor): + name_parts.append(scope_string_from_name(param.tensors)) else: raise ValueError("Encountered an unsupported param type {}".format( type(param))) @@ -255,33 +211,64 @@ def scalar_or_tensor_to_string(val): return repr(val) if np.isscalar(val) else scope_string_from_name(val) +def list_to_string(lst): + return "_".join(val if isinstance(val, six.string_types) + else scalar_or_tensor_to_string(val) for val in lst) + + +def graph_func_to_id(func): + """Returns a hashable object that represents func's computation.""" + # TODO(b/74201126): replace with Topohash of func's output + return func.func_id + + +def graph_func_to_string(func): + # TODO(b/74201126): replace with Topohash of func's output + return list_to_string(func.func_id) + + @six.add_metaclass(abc.ABCMeta) class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. - Note that for blocks that aren't based on approximations, a 'factor' can - be the entire block itself, as is the case for the diagonal and full - representations. + A FisherFactor represents part of an approximate Fisher Information matrix. + For example, one approximation to the Fisher uses the Kronecker product of two + FisherFactors A and B, F = kron(A, B). FisherFactors are composed with + FisherBlocks to construct a block-diagonal approximation to the full Fisher. + + FisherFactors are backed by a single, non-trainable variable that is updated + by running FisherFactor.make_covariance_update_op(). The shape and type of + this variable is implementation specific. - Subclasses must implement the _compute_new_cov method, and the _var_scope - and _cov_shape properties. + Note that for blocks that aren't based on approximations, a 'factor' can + be the entire block itself, as is the case for the diagonal and full + representations. """ def __init__(self): - self.instantiate_covariance() + self._cov = None @abc.abstractproperty def _var_scope(self): + """Variable scope for this FisherFactor instance. + + Returns: + string that unique identifies this FisherFactor instance. + """ pass + @property + def name(self): + return self._var_scope + @abc.abstractproperty def _cov_shape(self): - """The shape of the cov matrix.""" + """The shape of the variable backing this FisherFactor.""" pass @abc.abstractproperty def _num_sources(self): - """The number of things to sum over when computing cov. + """The number of things to sum over when updating covariance variable. The default make_covariance_update_op function will call _compute_new_cov with indices ranging from 0 to _num_sources-1. The typical situation is @@ -291,16 +278,23 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _num_towers(self): + pass + @abc.abstractproperty def _dtype(self): + """dtype for variable backing this factor.""" pass @property def _cov_initializer(self): + """Function for initializing covariance variable.""" return covariance_initializer - def instantiate_covariance(self): - """Instantiates the covariance Variable as the instance member _cov.""" + def instantiate_cov_variables(self): + """Makes the internal cov variable(s).""" + assert self._cov is None with variable_scope.variable_scope(self._var_scope): self._cov = variable_scope.get_variable( "cov", @@ -310,7 +304,18 @@ class FisherFactor(object): dtype=self._dtype) @abc.abstractmethod - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): + """Computes minibatch-estimated covariance for a single source. + + Args: + source: int in [0, self._num_sources). Which source to use when computing + the cov update. + tower: int in [0, self._num_towers). Which tower to use when computing + the cov update. + + Returns: + Tensor of same shape as self.get_cov_var(). + """ pass def make_covariance_update_op(self, ema_decay): @@ -321,36 +326,115 @@ class FisherFactor(object): Returns: An Op for updating the covariance Variable referenced by _cov. """ - new_cov_contribs = tuple(self._compute_new_cov(idx) - for idx in range(self._num_sources)) - # This gets the job done but we might want a better solution in the future. - # In particular, we could have a separate way of specifying where the - # the cov variables finally end up, independent of where their various - # contributions are computed. Right now these are the same thing, but in - # the future we might want to perform the cov computations on each tower, - # so that each tower will be considered a "source" (allowing us to reuse - # the existing "source" code for this). - with maybe_colocate_with(new_cov_contribs[0]): - new_cov = math_ops.add_n(new_cov_contribs) - # Synchronize value across all TPU cores. - if utils.on_tpu(): - new_cov = utils.cross_replica_mean(new_cov) - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + new_cov_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + device = (self._get_data_device(tower) + if TOWER_STRATEGY == "separate" else None) + with place_on_device(device): + new_cov_contribs.append(self._compute_new_cov(source, tower)) + + new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) + + # Compute average of 'new_cov' across all TPU cores. On a TPU, each + # instance of 'new_cov' will be based on a different minibatch. This ensures + # that by the end of assign_moving_average(), all TPU cores see the same + # value for self._cov. + # + # Other implementations of make_covariance_update_op() that accumulate + # statistics in other variables should mimic this behavior. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + + @abc.abstractmethod + def _get_data_device(self, tower): + pass + + @abc.abstractmethod + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + pass @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" pass + @abc.abstractmethod def get_cov(self): + """Get full covariance matrix. + + Returns: + Tensor of shape [n, n]. Represents all parameter-parameter correlations + captured by this FisherFactor. + """ + pass + + def get_cov_var(self): + """Get variable backing this FisherFactor. + + May or may not be the same as self.get_cov() + + Returns: + Variable of shape self._cov_shape. + """ return self._cov + @abc.abstractmethod + def left_multiply_matpower(self, x, exp, damping_func): + """Left multiplies 'x' by matrix power of this factor (w/ damping applied). + + This calculation is essentially: + (C + damping * I)**exp * x + where * is matrix-multiplication, ** is matrix power, I is the identity + matrix, and C is the matrix represented by this factor. + + x can represent either a matrix or a vector. For some factors, 'x' might + represent a vector but actually be stored as a 2D matrix for convenience. + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + exp: float. The matrix exponent to use. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + + Returns: + Tensor of same shape as 'x' representing the result of the multiplication. + """ + pass + + @abc.abstractmethod + def right_multiply_matpower(self, x, exp, damping_func): + """Right multiplies 'x' by matrix power of this factor (w/ damping applied). + + This calculation is essentially: + x * (C + damping * I)**exp + where * is matrix-multiplication, ** is matrix power, I is the identity + matrix, and C is the matrix represented by this factor. + + Unlike left_multiply_matpower, x will always be a matrix. + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + exp: float. The matrix exponent to use. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + + Returns: + Tensor of same shape as 'x' representing the result of the multiplication. + """ + pass + class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses, powers, etc of _cov. + """Base class for FisherFactors that maintain inverses explicitly. - Assumes that the _cov property is a square PSD matrix. + This class explicitly calculates and stores inverses of covariance matrices + provided by the underlying FisherFactor implementation. It is assumed that + vectors can be represented as 2-D matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. @@ -364,47 +448,52 @@ class InverseProvidingFactor(FisherFactor): # the latter. def __init__(self): - self._inverses_by_damping = {} - self._matpower_by_exp_and_damping = {} + self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } + self._matpower_registrations = set() # { (float, hashable) } self._eigendecomp = None + self._damping_funcs_by_id = {} # {hashable: lambda} super(InverseProvidingFactor, self).__init__() - def register_damped_inverse(self, damping): - """Registers a damped inverse needed by a FisherBlock. + def _register_damping(self, damping_func): + damping_id = graph_func_to_id(damping_func) + if damping_id not in self._damping_funcs_by_id: + self._damping_funcs_by_id[damping_id] = damping_func + return damping_id - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_inverse. + def register_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + self.register_matpower(-1, damping_func) - Args: - damping: The damping value (float or Tensor) for this factor. - """ - if damping not in self._inverses_by_damping: - damping_string = scalar_or_tensor_to_string(damping) - with variable_scope.variable_scope(self._var_scope): - inv = variable_scope.get_variable( - "inv_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - self._inverses_by_damping[damping] = inv - - def register_matpower(self, exp, damping): - """Registers a matrix power needed by a FisherBlock. + def register_matpower(self, exp, damping_func): + """Registers a matrix power to be maintained and served on demand. This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_matpower. Args: - exp: The exponent (float or Tensor) to raise the matrix to. - damping: The damping value (float or Tensor). + exp: float. The exponent to use in the matrix power. + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). """ - if (exp, damping) not in self._matpower_by_exp_and_damping: + if exp == 1.0: + # We don't register these. The user shouldn't even be calling this + # function with exp = 1.0. + return + + damping_id = self._register_damping(damping_func) + + if (exp, damping_id) not in self._matpower_registrations: + self._matpower_registrations.add((exp, damping_id)) + + def instantiate_inv_variables(self): + """Makes the internal "inverse" variable(s).""" + + for (exp, damping_id) in self._matpower_registrations: exp_string = scalar_or_tensor_to_string(exp) - damping_string = scalar_or_tensor_to_string(damping) + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) with variable_scope.variable_scope(self._var_scope): matpower = variable_scope.get_variable( "matpower_exp{}_damp{}".format(exp_string, damping_string), @@ -412,34 +501,35 @@ class InverseProvidingFactor(FisherFactor): shape=self._cov_shape, trainable=False, dtype=self._dtype) - self._matpower_by_exp_and_damping[(exp, damping)] = matpower + assert (exp, damping_id) not in self._matpower_by_exp_and_damping + self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] - # We do this to ensure that we don't reuse the eigendecomp from old calls - # to make_inverse_update_ops that may be placed on different devices. This - # can happen is the user has both a permanent and lazily constructed - # version of the inverse ops (and only uses one of them). - self.reset_eigendecomp() + num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping + if exp == -1) + + num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses + + other_matrix_power_registered = num_other_matpower >= 1 - num_inverses = len(self._inverses_by_damping) - matrix_power_registered = bool(self._matpower_by_exp_and_damping) use_eig = ( - self._eigendecomp or matrix_power_registered or + self._eigendecomp or other_matrix_power_registered or num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + # We precompute these so we don't need to evaluate them multiple times (for + # each matrix power that uses them) + damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]() + for damping_id in self._damping_funcs_by_id} + if use_eig: eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence - for damping, inv in self._inverses_by_damping.items(): - ops.append( - inv.assign( - math_ops.matmul(eigenvectors / (eigenvalues + damping), - array_ops.transpose(eigenvectors)))) - - for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + damping = damping_value_by_id[damping_id] ops.append( matpower.assign( math_ops.matmul(eigenvectors * @@ -448,28 +538,31 @@ class InverseProvidingFactor(FisherFactor): # These ops share computation and should be run on a single device. ops = [control_flow_ops.group(*ops)] else: - for damping, inv in self._inverses_by_damping.items(): - ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) + for (exp, damping_id), matpower in ( + self._matpower_by_exp_and_damping.items()): + assert exp == -1 + damping = damping_value_by_id[damping_id] + ops.append(matpower.assign(utils.posdef_inv(self._cov, damping))) + self._eigendecomp = False return ops - def get_damped_inverse(self, damping): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - return self._inverses_by_damping[damping] + def get_inverse(self, damping_func): + # Just for backwards compatibility of some old code and tests + damping_id = graph_func_to_id(damping_func) + return self._matpower_by_exp_and_damping[(-1, damping_id)] - def get_matpower(self, exp, damping): + def get_matpower(self, exp, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). - return self._matpower_by_exp_and_damping[(exp, damping)] + damping_id = graph_func_to_id(damping_func) + return self._matpower_by_exp_and_damping[(exp, damping_id)] def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" - # Unlike get_inverse and get_matpower this doesn't retrieve a stored - # variable, but instead always computes a fresh version from the current - # value of get_cov(). + # Unlike get_matpower this doesn't retrieve a stored variable, but instead + # always computes a fresh version from the current value of get_cov(). if not self._eigendecomp: eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) @@ -482,8 +575,42 @@ class InverseProvidingFactor(FisherFactor): return self._eigendecomp - def reset_eigendecomp(self): - self._eigendecomp = None + def get_cov(self): + # Variable contains full covariance matrix. + return self.get_cov_var() + + def left_multiply_matpower(self, x, exp, damping_func): + if isinstance(x, tf_ops.IndexedSlices): + raise ValueError("Left-multiply not yet supported for IndexedSlices.") + + if x.shape.ndims != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + if exp == 1: + return math_ops.matmul(self.get_cov(), x) + damping_func() * x + + return math_ops.matmul(self.get_matpower(exp, damping_func), x) + + def right_multiply_matpower(self, x, exp, damping_func): + if isinstance(x, tf_ops.IndexedSlices): + if exp == 1: + n = self.get_cov().shape[0] + damped_cov = self.get_cov() + damping_func() * array_ops.eye(n) + return utils.matmul_sparse_dense(x, damped_cov) + + return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func)) + + if x.shape.ndims != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + if exp == 1: + return math_ops.matmul(x, self.get_cov()) + damping_func() * x + + return math_ops.matmul(x, self.get_matpower(exp, damping_func)) class FullFactor(InverseProvidingFactor): @@ -503,7 +630,7 @@ class FullFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_full/" + scope_string_from_params( + return "ff_full_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property @@ -516,23 +643,36 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): + def _compute_new_cov(self, source, tower): + assert tower == 0 + # This will be a very basic rank 1 estimate - with maybe_colocate_with(self._params_grads[idx]): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) - return ((params_grads_flat * array_ops.transpose( - params_grads_flat)) / math_ops.cast(self._batch_size, - params_grads_flat.dtype)) + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None class DiagonalFactor(FisherFactor): - """A base class for FisherFactors that use diagonal approximations.""" + """A base class for FisherFactors that use diagonal approximations. + + A DiagonalFactor's covariance variable can be of any shape, but must contain + exactly one entry per parameter. + """ def __init__(self): + self._damping_funcs_by_id = {} # { hashable: lambda } super(DiagonalFactor, self).__init__() @property @@ -542,6 +682,32 @@ class DiagonalFactor(FisherFactor): def make_inverse_update_ops(self): return [] + def instantiate_inv_variables(self): + pass + + def get_cov(self): + # self.get_cov() could be any shape, but it must have one entry per + # parameter. Flatten it into a vector. + cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) + return array_ops.diag(cov_diag_vec) + + def left_multiply_matpower(self, x, exp, damping_func): + matpower = (self.get_cov_var() + damping_func())**exp + + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x) + + if x.shape != matpower.shape: + raise ValueError("x (%s) and cov (%s) must have same shape." % + (x, matpower)) + return matpower * x + + def right_multiply_matpower(self, x, exp, damping_func): + raise NotImplementedError("Only left-multiply is currently supported.") + + def register_matpower(self, exp, damping_func): + pass + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -553,6 +719,14 @@ class NaiveDiagonalFactor(DiagonalFactor): def __init__(self, params_grads, batch_size): + """Initializes NaiveDiagonalFactor instance. + + Args: + params_grads: Sequence of Tensors, each with same shape as parameters this + FisherFactor corresponds to. For example, the gradient of the loss with + respect to parameters. + batch_size: int or 0-D Tensor. Size + """ self._params_grads = tuple(utils.ensure_sequence(params_grad) for params_grad in params_grads) self._batch_size = batch_size @@ -560,28 +734,123 @@ class NaiveDiagonalFactor(DiagonalFactor): @property def _var_scope(self): - return "ff_naivediag/" + scope_string_from_params( + return "ff_naivediag_" + scope_string_from_params( [self._params_grads, self._batch_size]) @property def _cov_shape(self): size = sum(param_grad.shape.num_elements() for param_grad in self._params_grads[0]) - return (size, 1) + return [size, 1] @property def _num_sources(self): return len(self._params_grads) + @property + def _num_towers(self): + return 1 + @property def _dtype(self): return self._params_grads[0][0].dtype - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._params_grads[idx]): - params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) - return (math_ops.square(params_grads_flat) / math_ops.cast( - self._batch_size, params_grads_flat.dtype)) + def _compute_new_cov(self, source, tower): + assert tower == 0 + + params_grads_flat = utils.tensors_to_column(self._params_grads[source]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) + + def _get_data_device(self, tower): + return None + + +class EmbeddingInputKroneckerFactor(DiagonalFactor): + r"""FisherFactor for input to an embedding layer. + + Given input_ids = [batch_size, input_size] representing indices into an + [vocab_size, embedding_size] embedding matrix, approximate input covariance by + a diagonal matrix, + + Cov(input_ids, input_ids) = + (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). + + where n_hot() constructs an n-hot binary vector and diag() constructs a + diagonal matrix of size [vocab_size, vocab_size]. + """ + + def __init__(self, input_ids, vocab_size, dtype=None): + """Instantiate EmbeddingInputKroneckerFactor. + + Args: + input_ids: List of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. List index is tower. + vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. + dtype: dtype for covariance statistics. Must be a floating point type. + Defaults to float32. + """ + self._input_ids = input_ids + self._vocab_size = vocab_size + self._cov_dtype = dtype or dtypes.float32 + + super(EmbeddingInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) + + @property + def _cov_shape(self): + return [self._vocab_size] + + @property + def _num_sources(self): + return 1 + + @property + def _num_towers(self): + return len(self._input_ids) + + @property + def _dtype(self): + return self._cov_dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + input_ids = self._input_ids[tower] + + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) + + batch_size = array_ops.shape(input_ids)[0] + + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] + + # Take average across examples. Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _get_data_device(self, tower): + return self._input_ids[tower].device class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -602,57 +871,75 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): """Instantiate FullyConnectedDiagonalFactor. Args: - inputs: Tensor of shape [batch_size, input_size]. Inputs to fully - connected layer. - outputs_grads: List of Tensors of shape [batch_size, output_size]. - Gradient of loss with respect to layer's preactivations. + inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this + layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, output_size], + which are the gradients of the loss with respect to the layer's + outputs. First index is source, second is tower. + has_bias: bool. If True, append '1' to each input. """ self._inputs = inputs self._has_bias = has_bias self._outputs_grads = outputs_grads - self._batch_size = array_ops.shape(inputs)[0] self._squared_inputs = None super(FullyConnectedDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_diagfc/" + scope_string_from_params( - (self._inputs,) + tuple(self._outputs_grads)) + return "ff_diagfc_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): - return [self._inputs.shape[1] + self._has_bias, - self._outputs_grads[0].shape[1]] + input_size = self._inputs[0].shape[1] + self._has_bias + output_size = self._outputs_grads[0][0].shape[1] + return [input_size, output_size] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._outputs_grads[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + self._squared_inputs = [] + for tower in range(self._num_towers): + inputs = self._inputs[tower] + + with place_on_device(self._get_data_device(tower)): + if self._has_bias: + inputs = append_homog(inputs) + self._squared_inputs.append(math_ops.square(inputs)) + + return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( + ema_decay) + + def _compute_new_cov(self, source, tower): + batch_size = array_ops.shape(self._squared_inputs[tower])[0] + outputs_grad = self._outputs_grads[source][tower] - def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. - with maybe_colocate_with(self._outputs_grads[idx]): - # We only need to compute squared_inputs once - if self._squared_inputs is None: - inputs = self._inputs - if self._has_bias: - inputs = append_homog(self._inputs) - self._squared_inputs = math_ops.square(inputs) + new_cov = math_ops.matmul( + self._squared_inputs[tower], + math_ops.square(outputs_grad), + transpose_a=True) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + return new_cov - new_cov = math_ops.matmul( - self._squared_inputs, - math_ops.square(self._outputs_grads[idx]), - transpose_a=True) - new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) - return new_cov + def _get_data_device(self, tower): + return self._inputs[tower].device class ConvDiagonalFactor(DiagonalFactor): @@ -664,36 +951,67 @@ class ConvDiagonalFactor(DiagonalFactor): filter_shape, strides, padding, + data_format=None, + dilations=None, has_bias=False): """Creates a ConvDiagonalFactor object. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. - Input activations to this layer. - outputs_grads: Tensor of shape [batch_size, height, width, out_channels]. - Per-example gradients to the loss with respect to the layer's output - preactivations. + inputs: List of Tensors of shape [batch_size, height, width, in_channels]. + Input activations to this layer. List index is towers. + outputs_grads: List of Tensors, each of shape [batch_size, + height, width, out_channels], which are the gradients of the loss + with respect to the layer's outputs. First index is source, second + index is tower. filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, out_channels). Represents shape of kernel used in this layer. strides: The stride size in this layer (1-D Tensor of length 4). padding: The padding in this layer (1-D of Tensor length 4). + data_format: None or str. Format of conv2d inputs. + dilations: None or tuple of 4 ints. has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + + Raises: + ValueError: If inputs, output_grads, and filter_shape do not agree on + in_channels or out_channels. + ValueError: If strides, dilations are not length-4 lists of ints. + ValueError: If data_format does not put channel last. """ + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + if any(input_.shape.ndims != 4 for input_ in inputs): + raise ValueError("inputs must be a list of 4-D Tensors.") + if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): + raise ValueError("inputs and filter_shape must agree on in_channels.") + for i, outputs_grad in enumerate(outputs_grads): + if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): + raise ValueError("outputs[%d] must be 4-D Tensor." % i) + if any(output_grad.shape.as_list()[-1] != filter_shape[-1] + for output_grad in outputs_grad): + raise ValueError( + "outputs[%d] and filter_shape must agree on out_channels." % i) + if len(strides) != 4: + raise ValueError("strides must be length-4 list of ints.") + if dilations is not None and len(dilations) != 4: + raise ValueError("dilations must be length-4 list of ints.") + self._inputs = inputs + self._outputs_grads = outputs_grads self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._data_format = data_format + self._dilations = dilations self._has_bias = has_bias - self._outputs_grads = outputs_grads self._patches = None super(ConvDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_convdiag/" + scope_string_from_name( - (self._inputs,) + tuple(self._outputs_grads)) + return "ff_convdiag_" + scope_string_from_params( + tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) @property def _cov_shape(self): @@ -707,42 +1025,50 @@ class ConvDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._inputs) + @property def _dtype(self): - return self._outputs_grads[0].dtype + return self._inputs[0].dtype def make_covariance_update_op(self, ema_decay): - with maybe_colocate_with(self._inputs): - filter_height, filter_width, _, _ = self._filter_shape - - # TODO(b/64144716): there is potential here for a big savings in terms - # of memory use. - patches = extract_image_patches( - self._inputs, - ksizes=[filter_height, filter_width], - strides=self._strides[1:-1], - padding=self._padding) + filter_height, filter_width, _, _ = self._filter_shape - if self._has_bias: - patches = append_homog(patches) + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + if self._dilations is None: + rates = (1, 1, 1, 1) + else: + rates = tuple(self._dilations) + + self._patches = [] + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + patches = array_ops.extract_image_patches( + self._inputs[tower], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=rates, + padding=self._padding) - self._patches = patches + if self._has_bias: + patches = append_homog(patches) - op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + self._patches.append(patches) - self._patches = None + return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) - return op + def _compute_new_cov(self, source, tower): + patches = self._patches[tower] + batch_size = array_ops.shape(patches)[0] + outputs_grad = self._outputs_grads[source][tower] - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - outputs_grad = self._outputs_grads[idx] - batch_size = array_ops.shape(self._patches)[0] + new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) - - return new_cov + return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". @@ -752,6 +1078,9 @@ class ConvDiagonalFactor(DiagonalFactor): outputs_grad) return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + def _get_data_device(self, tower): + return self._inputs[tower].device + class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. @@ -763,8 +1092,9 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Instantiate FullyConnectedKroneckerFactor. Args: - tensors: List of Tensors of shape [batch_size, n]. Represents either a - layer's inputs or its output's gradients. + tensors: List of list of Tensors, each of shape [batch_size, n]. The + Tensors are typically either a layer's inputs or its output's gradients. + The first list index is source, the second is tower. has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of @@ -775,28 +1105,34 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): @property def _var_scope(self): - return "ff_fckron/" + scope_string_from_params( - [self._tensors, self._has_bias]) + return "ff_fckron_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + (self._has_bias,)) @property def _cov_shape(self): - size = self._tensors[0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size, size] @property def _num_sources(self): return len(self._tensors) + @property + def _num_towers(self): + return len(self._tensors[0]) + @property def _dtype(self): - return self._tensors[0].dtype + return self._tensors[0][0].dtype + + def _compute_new_cov(self, source, tower): + tensor = self._tensors[source][tower] + if self._has_bias: + tensor = append_homog(tensor) + return compute_cov(tensor) - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._tensors[idx]): - tensor = self._tensors[idx] - if self._has_bias: - tensor = append_homog(tensor) - return compute_cov(tensor) + def _get_data_device(self, tower): + return self._tensors[0][tower].device class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -812,39 +1148,55 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def __init__(self, inputs, filter_shape, - strides, padding, + strides=None, + dilation_rate=None, + data_format=None, + extract_patches_fn=None, has_bias=False): """Initializes ConvInputKroneckerFactor. Args: - inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs - to layer. - filter_shape: 1-D Tensor of length 4. Contains [kernel_height, - kernel_width, in_channels, out_channels]. - strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride, - width_stride, in_channel_stride]. + inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., + in_channels]. Inputs to layer. List index is tower. + filter_shape: List of ints. Contains [..spatial_filter_size.., + in_channels, out_channels]. Shape of convolution kernel. padding: str. Padding method for layer. "SAME" or "VALID". + strides: List of ints or None. Contains [..spatial_filter_strides..] if + 'extract_patches_fn' is compatible with tf.nn.convolution(), else + [1, ..spatial_filter_strides, 1]. + dilation_rate: List of ints or None. Rate for dilation along each spatial + dimension if 'extract_patches_fn' is compatible with + tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. + data_format: str or None. Format of input data. + extract_patches_fn: str or None. Name of function that extracts image + patches. One of "extract_convolution_patches", "extract_image_patches", + "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. """ + self._inputs = inputs self._filter_shape = filter_shape self._strides = strides self._padding = padding + self._dilation_rate = dilation_rate + self._data_format = data_format + self._extract_patches_fn = extract_patches_fn self._has_bias = has_bias - self._inputs = inputs + super(ConvInputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convinkron/" + scope_string_from_params([ - self._inputs, self._filter_shape, self._strides, self._padding, - self._has_bias - ]) + return "ff_convinkron_" + scope_string_from_params( + tuple(self._inputs) + + tuple((self._filter_shape, self._strides, self._padding, + self._dilation_rate, self._data_format, self._has_bias))) @property def _cov_shape(self): - filter_height, filter_width, in_channels, _ = self._filter_shape - size = filter_height * filter_width * in_channels + self._has_bias + spatial_filter_shape = self._filter_shape[0:-2] + in_channels = self._filter_shape[-2] + size = np.prod(spatial_filter_shape) * in_channels + self._has_bias return [size, size] @property @@ -852,43 +1204,77 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): return 1 @property - def _dtype(self): - return self._inputs.dtype - - def _compute_new_cov(self, idx=0): - if idx != 0: - raise ValueError("ConvInputKroneckerFactor only supports idx = 0") + def _num_towers(self): + return len(self._inputs) - with maybe_colocate_with(self._inputs): - filter_height, filter_width, in_channels, _ = self._filter_shape - - # TODO(b/64144716): there is potential here for a big savings in terms of - # memory use. - patches = extract_image_patches( - self._inputs, - ksizes=[filter_height, filter_width], - strides=self._strides[1:-1], + @property + def _dtype(self): + return self._inputs[0].dtype + + def _compute_new_cov(self, source, tower): + assert source == 0 + + inputs = self._inputs[tower] + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + if self._extract_patches_fn in [None, "extract_convolution_patches"]: + patches = utils.extract_convolution_patches( + inputs, + self._filter_shape, + padding=self._padding, + strides=self._strides, + dilation_rate=self._dilation_rate, + data_format=self._data_format) + + elif self._extract_patches_fn == "extract_image_patches": + assert inputs.shape.ndims == 4 + assert len(self._filter_shape) == 4 + assert len(self._strides) == 4, self._strides + if self._dilation_rate is None: + rates = [1, 1, 1, 1] + else: + rates = self._dilation_rate + assert len(rates) == 4 + assert rates[0] == rates[-1] == 1 + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1] + list(self._filter_shape[0:-2]) + [1], + strides=self._strides, + rates=rates, padding=self._padding) - flatten_size = (filter_height * filter_width * in_channels) - # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde - # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), - # where M = minibatch size, |T| = number of spatial locations, - # |Delta| = number of spatial offsets, and J = number of input maps - # for convolutional layer l. - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - # We append a homogenous coordinate to patches_flat if the layer has - # bias parameters. This gives us [[A_l]]_H from the paper. - if self._has_bias: - patches_flat = append_homog(patches_flat) - # We call compute_cov without passing in a normalizer. compute_cov uses - # the first dimension of patches_flat i.e. M|T| as the normalizer by - # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with - # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from - # the paper but has a different scale here for consistency with - # ConvOutputKroneckerFactor. - # (Tilde omitted over A for clarity.) - return compute_cov(patches_flat) + elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": + assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] + assert self._filter_shape[0] == self._filter_shape[1] == 1 + patches = utils.extract_pointwise_conv2d_patches( + inputs, self._filter_shape, data_format=None) + + else: + raise NotImplementedError(self._extract_patches_fn) + + flatten_size = np.prod(self._filter_shape[0:-1]) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. + if self._has_bias: + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) + + def _get_data_device(self, tower): + return self._inputs[tower].device class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -902,20 +1288,28 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, data_format=None): """Initializes ConvOutputKroneckerFactor. Args: - outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, height, width, out_channels]. + outputs_grads: List of list of Tensors. Each Tensor is of shape + [batch_size, ..spatial_input_size.., out_channels]. First list index + is source, the second is tower. + data_format: None or str. Format of outputs_grads. + + Raises: + ValueError: If channels are not final dimension. """ - self._out_channels = outputs_grads[0].shape.as_list()[3] + if not utils.is_data_format_channel_last(data_format): + raise ValueError("Channel must be last.") + self._out_channels = outputs_grads[0][0].shape.as_list()[-1] self._outputs_grads = outputs_grads super(ConvOutputKroneckerFactor, self).__init__() @property def _var_scope(self): - return "ff_convoutkron/" + scope_string_from_params(self._outputs_grads) + return "ff_convoutkron_" + scope_string_from_params( + nest.flatten(self._outputs_grads)) @property def _cov_shape(self): @@ -926,134 +1320,150 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _num_towers(self): + return len(self._outputs_grads[0]) + @property def _dtype(self): - return self._outputs_grads[0].dtype - - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._outputs_grads[idx]): - # reshaped_tensor below is the matrix DS_l defined in the KFC paper - # (tilde omitted over S for clarity). It has shape M|T| x I, where - # M = minibatch size, |T| = number of spatial locations, and - # I = number of output maps for convolutional layer l. - reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], - [-1, self._out_channels]) - # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, - # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l - # as defined in the paper, with shape I x I. - # (Tilde omitted over S for clarity.) - return compute_cov(reshaped_tensor) - - -class FullyConnectedMultiKF(InverseProvidingFactor): - """Kronecker factor for a fully connected recurrent layer.""" + return self._outputs_grads[0][0].dtype + + def _compute_new_cov(self, source, tower): + outputs_grad = self._outputs_grads[source][tower] + + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. + reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels]) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + def _get_data_device(self, tower): + return self._outputs_grads[0][tower].device + + +class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): + """Kronecker factor for a fully connected layer used multiple times.""" def __init__(self, - tensor_lists, + tensors, + num_uses=None, has_bias=False): """Constructs a new `FullyConnectedMultiKF`. Args: - tensor_lists: List of lists of Tensors of shape [batch_size, n]. + tensors: List of list of Tensors of shape, each of shape + [num_uses * batch_size, n], and is a reshape version of a Tensor of + shape [num_uses, batch_size, n]. Each of these tensors is usually a + layer's inputs or its output's gradients. The first list index is + sources, the second is towers. + num_uses: int. The number of time-steps / uses. has_bias: bool. If True, '1' is appended to each row. """ - self._tensor_lists = tensor_lists - self._has_bias = has_bias - self._batch_size = array_ops.shape(tensor_lists[0][0])[0] - self._num_timesteps = len(tensor_lists[0]) - self._tensors = [None] * len(tensor_lists) + self._num_uses = num_uses self._cov_dt1 = None + self._make_cov_dt1 = False self._option1quants_by_damping = {} self._option2quants_by_damping = {} + self._option1quants_registrations = set() + self._option2quants_registrations = set() - super(FullyConnectedMultiKF, self).__init__() + super(FullyConnectedMultiKF, self).__init__(tensors=tensors, + has_bias=has_bias) @property - def _var_scope(self): - return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists) + def _num_timesteps(self): + return self._num_uses @property - def _num_sources(self): - return len(self._tensor_lists) - - @property - def _dtype(self): - return self._tensor_lists[0][0].dtype + def _var_scope(self): + return "ff_fc_multi_" + scope_string_from_params( + tuple(nest.flatten(self._tensors)) + + (self._num_timesteps, self._has_bias,)) def make_covariance_update_op(self, ema_decay): op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) if self._cov_dt1 is not None: - new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) - for idx in range(self._num_sources)) + new_cov_dt1_contribs = [] + for source in range(self._num_sources): + for tower in range(self._num_towers): + with place_on_device(self._get_data_device(tower)): + new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, + tower)) - with maybe_colocate_with(new_cov_dt1_contribs[0]): - new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) + / float(self._num_towers)) - op2 = moving_averages.assign_moving_average( - self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + # See comments in FisherFactor.make_covariance_update_op() for details. + if utils.on_tpu(): + new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) + + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) - # TODO(b/69112164): - # It's important that _cov and _cov_dt1 remain consistent with each - # other while the inverse ops are happening. How can we ensure this? - # We will need to add explicit synchronization for this to - # work with asynchronous training. - op = control_flow_ops.group(op, op2) + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) return op - def _compute_new_cov(self, idx=0): - with maybe_colocate_with(self._tensor_lists[idx]): - tensor = array_ops.concat(self._tensor_lists[idx], 0) - if self._has_bias: - tensor = append_homog(tensor) - # We save these so they can be used by _compute_new_cov_dt1 - self._tensors[idx] = tensor - return compute_cov(tensor) - - def _compute_new_cov_dt1(self, idx=0): - tensor = self._tensors[idx] - with maybe_colocate_with(tensor): - # Is there a more elegant way to do this computation? - tensor_present = tensor[:-self._batch_size, :] - tensor_future = tensor[self._batch_size:, :] - # We specify a normalizer for this computation to ensure a PSD Fisher - # block estimate. This is equivalent to padding with zeros, as was done - # in Section B.2 of the appendix. - normalizer = self._num_timesteps * self._batch_size - return compute_cov( - tensor_future, tensor_right=tensor_present, normalizer=normalizer) + def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring + tensor = self._tensors[source][tower] + if self._has_bias: + # This appending is technically done twice (the other time is for + # _compute_new_cov()) + tensor = append_homog(tensor) - @property - def _cov_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias - return [size, size] + total_len = array_ops.shape(tensor)[0] + batch_size = total_len // self._num_timesteps + + tensor_present = tensor[:-batch_size, :] + tensor_future = tensor[batch_size:, :] + + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=total_len) + + def _get_data_device(self, tower): + return self._tensors[0][tower].device @property def _vec_shape(self): - size = self._tensor_lists[0][0].shape[1] + self._has_bias + size = self._tensors[0][0].shape[1] + self._has_bias return [size] - def get_option1quants(self, damping): - return self._option1quants_by_damping[damping] + def get_option1quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option1quants_by_damping[damping_id] - def get_option2quants(self, damping): - return self._option2quants_by_damping[damping] + def get_option2quants(self, damping_func): + damping_id = graph_func_to_id(damping_func) + return self._option2quants_by_damping[damping_id] def get_cov_dt1(self): assert self._cov_dt1 is not None return self._cov_dt1 def register_cov_dt1(self): - """Create a variable representing temporal cross-covariance. + self._make_cov_dt1 = True - (This is technically the second moment, not covariance, since it's - not mean subtracted.) - """ - if self._cov_dt1 is None: + def instantiate_cov_variables(self): + super(FullyConnectedMultiKF, self).instantiate_cov_variables() + assert self._cov_dt1 is None + if self._make_cov_dt1: with variable_scope.variable_scope(self._var_scope): self._cov_dt1 = variable_scope.get_variable( "cov_dt1", @@ -1062,15 +1472,25 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - def register_option1quants(self, damping): + def register_option1quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option1quants_registrations: + self._option1quants_registrations.add(damping_id) + + def register_option2quants(self, damping_func): + damping_id = self._register_damping(damping_func) + if damping_id not in self._option2quants_registrations: + self._option2quants_registrations.add(damping_id) - self.register_cov_dt1() + def instantiate_inv_variables(self): + super(FullyConnectedMultiKF, self).instantiate_inv_variables() - if damping not in self._option1quants_by_damping: + for damping_id in self._option1quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. - damping_string = scalar_or_tensor_to_string(damping) with variable_scope.variable_scope(self._var_scope): Lmat = variable_scope.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), @@ -1085,17 +1505,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - self._option1quants_by_damping[damping] = (Lmat, psi) - - def register_option2quants(self, damping): - - self.register_cov_dt1() + assert damping_id not in self._option1quants_by_damping + self._option1quants_by_damping[damping_id] = (Lmat, psi) - if damping not in self._option2quants_by_damping: + for damping_id in self._option2quants_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) # It's questionable as to whether we should initialize with stuff like # this at all. Ideally these values should never be used until they are # updated at least once. - damping_string = scalar_or_tensor_to_string(damping) with variable_scope.variable_scope(self._var_scope): Pmat = variable_scope.get_variable( # pylint: disable=invalid-name "Lmat_damp{}".format(damping_string), @@ -1116,14 +1534,15 @@ class FullyConnectedMultiKF(InverseProvidingFactor): trainable=False, dtype=self._dtype) - self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + assert damping_id not in self._option2quants_by_damping + self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" # TODO(b/69918258): Add correctness tests for this method. # pylint: disable=invalid-name - ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + ops = [] if (len(self._option1quants_by_damping) + len(self._option2quants_by_damping)): @@ -1144,8 +1563,10 @@ class FullyConnectedMultiKF(InverseProvidingFactor): # consistently, or are somehow read between or during the cov updates. # Can this possibly happen? Is there a way to prevent it? - for damping, (Lmat_var, - psi_var) in self._option1quants_by_damping.items(): + for damping_id, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() invsqrtC0 = math_ops.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) @@ -1170,8 +1591,10 @@ class FullyConnectedMultiKF(InverseProvidingFactor): ops.append(Lmat_var.assign(Lmat)) ops.append(psi_var.assign(psi)) - for damping, (Pmat_var, Kmat_var, - mu_var) in self._option2quants_by_damping.items(): + for damping_id, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + damping = self._damping_funcs_by_id[damping_id]() # compute C0^(-1/2) invsqrtC0 = math_ops.matmul( @@ -1212,6 +1635,7 @@ class FullyConnectedMultiKF(InverseProvidingFactor): ops.append(Kmat_var.assign(Kmat)) ops.append(mu_var.assign(mu)) + ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() return [control_flow_ops.group(*ops)] # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index ad93919149c287b1932dd2b6bd772c0dab26192d..2d8e378a932c16d48360bc4b15ff4f3239c0ed1f 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -24,26 +24,15 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ - "inverse_initializer", - "covariance_initializer", - "diagonal_covariance_initializer", - "scope_string_from_params", - "scope_string_from_name", - "scalar_or_tensor_to_string", - "FisherFactor", - "InverseProvidingFactor", - "FullFactor", - "DiagonalFactor", - "NaiveDiagonalFactor", - "FullyConnectedDiagonalFactor", - "FullyConnectedKroneckerFactor", - "ConvInputKroneckerFactor", - "ConvOutputKroneckerFactor", - "ConvDiagonalFactor", - "set_global_constants", - "maybe_colocate_with", - "compute_cov", - "append_homog" + "inverse_initializer", "covariance_initializer", + "diagonal_covariance_initializer", "scope_string_from_params", + "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", + "InverseProvidingFactor", "FullFactor", "DiagonalFactor", + "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", + "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", + "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", + "compute_cov", "append_homog" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 8d450f04f379701e46a18b2e34bbbd6fcfcce2bb..366e2a82d56602de0df706cbd382c21aba5540af 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from contextlib import contextmanager from functools import partial +import warnings import math import six @@ -59,6 +61,10 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +_EMBEDDING_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB +} + APPROX_KRONECKER_INDEP_NAME = "kron_indep" APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" @@ -71,10 +77,39 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { option=2) } -# Possible value for 'reuse' keyword argument. Sets 'reuse' to +_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB +} + +_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB +} + +# Possible value for `reuse` keyword argument. Sets `reuse` to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" +_DEFAULT_LAYER_COLLECTION = None + + +def get_default_layer_collection(): + """Get default LayerCollection.""" + if _DEFAULT_LAYER_COLLECTION is None: + raise ValueError( + "Attempted to retrieve default LayerCollection when none is set. Use " + "LayerCollection.as_default().") + + return _DEFAULT_LAYER_COLLECTION + + +def set_default_layer_collection(layer_collection): + global _DEFAULT_LAYER_COLLECTION + + if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: + raise ValueError("Default LayerCollection is already set.") + + _DEFAULT_LAYER_COLLECTION = layer_collection + class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -130,11 +165,16 @@ class LayerCollection(object): fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. losses: a list of LossFunction objects. The loss to be optimized is their sum. + loss_colocation_ops: ops to colocate loss function evaluations with. These + will typically be the inputs to the losses. """ def __init__(self, graph=None, name="LayerCollection"): + warnings.warn( + "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. " + "Use https://pypi.python.org/pypi/kfac instead.") self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() self._linked_parameters = dict( @@ -143,18 +183,29 @@ class LayerCollection(object): self._loss_dict = {} # {str: LossFunction} self._subgraph = None self._default_generic_approximation = APPROX_FULL_NAME + self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_conv2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_SERIES_2_NAME) + APPROX_KRONECKER_INDEP_NAME) + self._default_conv2d_multi_approximation = ( + APPROX_KRONECKER_INDEP_NAME) + self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME + self.loss_colocation_ops = {} + self._vars_to_uses = defaultdict(lambda: 0) with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @property def losses(self): - """LossFunctions registered with this LayerCollection.""" - return list(self._loss_dict.values()) + """Tuple of LossFunction objects registered with this LayerCollection.""" + return nest.flatten(self.towers_by_loss) + + @property + def towers_by_loss(self): + """Tuple across losses of LossFunction objects registered to each tower.""" + return tuple(tuple(lst) for lst in self._loss_dict.values()) @property def registered_variables(self): @@ -178,6 +229,17 @@ class LayerCollection(object): """ return self._linked_parameters + @property + def default_embedding_approximation(self): + return self._default_embedding_approximation + + def set_default_embedding_approximation(self, value): + if value != APPROX_KRONECKER_NAME: + raise ValueError( + "{} is not a valid approximation for embedding variables.".format( + value)) + self._default_embedding_approximation = value + @property def default_generic_approximation(self): return self._default_generic_approximation @@ -202,14 +264,14 @@ class LayerCollection(object): @property def default_conv2d_approximation(self): - return self._default_convolution_2d_approximation + return self._default_conv2d_approximation def set_default_conv2d_approximation(self, value): if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError( "{} is not a valid approximation for 2d convolutional layers.".format( value)) - self._default_convolution_2d_approximation = value + self._default_conv2d_approximation = value @property def default_fully_connected_multi_approximation(self): @@ -221,6 +283,14 @@ class LayerCollection(object): "multi layer.".format(value)) self._default_fully_connected_multi_approximation = value + @property + def default_conv2d_multi_approximation(self): + return self._default_conv2d_multi_approximation + + @property + def default_embedding_multi_approximation(self): + return self._default_embedding_multi_approximation + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -228,8 +298,8 @@ class LayerCollection(object): layer_key: A variable or tuple of variables. The key to check for in existing registrations and to register if valid. fisher_block: The associated `FisherBlock`. - reuse: Method to use for inserting new `FisherBlock`s. One of True, False, - or 'VARIABLE_SCOPE'. + reuse: Method to use for inserting new `FisherBlock's. One of True, False, + or `VARIABLE_SCOPE`. Raises: ValueError: If `layer_key` was already registered and reuse is `False`, @@ -278,23 +348,73 @@ class LayerCollection(object): self.fisher_blocks[layer_key] = fisher_block return fisher_block - def get_use_count_map(self): - """Returns a dict of variables to their number of registrations.""" - # TODO(b/70283403): Reimplement this in the old way, where each - # registration function would be responsible for incrementing the count. - # Also, this version has a bug: it won't do the right thing for generic - # registration for parameters that are shared. i.e. it won't set the use - # count to infinity. - vars_to_uses = defaultdict(int) - for key, block in six.iteritems(self.fisher_blocks): - n = ( - block.num_inputs()*block.num_registered_minibatches if isinstance( - block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) - else block.num_registered_minibatches) - key = utils.ensure_sequence(key) - for k in key: - vars_to_uses[k] += n - return vars_to_uses + def register_loss_function(self, + loss, + colocation_op, + base_name, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a LossFunction object. + + Args: + loss: The LossFunction object. + colocation_op: The op to colocate the loss function's computations with. + base_name: The name to derive a new unique name from is the name argument + is None. + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional + tower for the existing loss function. + + Raises: + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with `name` found. + KeyError: If reuse == False and existing LossFunction with `name` found. + """ + + name = name or self._graph.unique_name(base_name) + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + + loss_list = self._loss_dict.get(name, None) + + if loss_list is None: + raise KeyError( + "Unable to find loss function named {}. Register a new loss " + "function with reuse=False.".format(name)) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another tower.".format(name)) + + loss_list = [] + self._loss_dict[name] = loss_list + + loss_list.append(loss) + self.loss_colocation_ops[loss] = colocation_op + + def _get_use_count_map(self): + """Returns a dict mapping variables to their number of registrations.""" + return self._vars_to_uses + + def _add_uses(self, params, uses): + """Register additional uses by params in the graph. + + Args: + params: Variable or tuple of Variables. Parameters for a layer. + uses: int or float. Number of additional uses for these parameters. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + for var in params: + self._vars_to_uses[var] += uses def check_registration(self, variables): """Checks that all variable uses have been registered properly. @@ -312,7 +432,7 @@ class LayerCollection(object): # Note that overlapping parameters (i.e. those that share variables) will # be caught by layer_collection.LayerParametersDict during registration. - reg_use_map = self.get_use_count_map() + reg_use_map = self._get_use_count_map() error_messages = [] @@ -374,24 +494,24 @@ class LayerCollection(object): """ params = frozenset(utils.ensure_sequence(params)) - # Check if any of the variables in 'params' is already in - # 'self.fisher_blocks.keys()'. + # Check if any of the variables in `params` is already in + # 'self.fisher_blocks.keys()`. for registered_params, fisher_block in self.fisher_blocks.items(): registered_params_set = set(utils.ensure_sequence(registered_params)) for variable in params: if (variable in registered_params_set and params != registered_params_set): raise ValueError( - "Can't link parameters {}, variable {} was already registered in " + "Can`t link parameters {}, variable {} was already registered in " "group {} with layer {}".format(params, variable, registered_params, fisher_block)) - # Check if any of the variables in 'params' is already in - # 'self.linked_parameters'. + # Check if any of the variables in `params` is already in + # 'self.linked_parameters`. for variable in params: for other_linked_params in self.linked_parameters: if variable in other_linked_params: - raise ValueError("Can't link parameters {}, variable {} was already " + raise ValueError("Can`t link parameters {}, variable {} was already " "linked in group {}.".format(params, variable, other_linked_params)) self._linked_parameters[params] = approximation @@ -402,12 +522,27 @@ class LayerCollection(object): inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) self._subgraph = utils.SubGraph(inputs_to_losses) + def eval_losses(self): + """Return evaluated losses (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate()) + return evals + + def eval_losses_on_samples(self): + """Return losses evaluated on samples (colocated with inputs to losses).""" + evals = [] + for loss in self.losses: + with ops.colocate_with(self.loss_colocation_ops[loss]): + evals.append(loss.evaluate_on_sample()) + return evals + def total_loss(self): - return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses)) + return math_ops.add_n(self.eval_losses()) def total_sampled_loss(self): - return math_ops.add_n( - tuple(loss.evaluate_on_sample() for loss in self.losses)) + return math_ops.add_n(self.eval_losses_on_samples()) def _get_linked_approx(self, params): """If params were linked, return their specified approximation.""" @@ -417,6 +552,57 @@ class LayerCollection(object): else: return None + def _get_block_type(self, params, approx, default, approx_to_type): + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = default + + if approx not in approx_to_type: + raise ValueError("Bad value {} for approx.".format(approx)) + + return approx_to_type[approx], approx + + def register_embedding(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers an embedding layer. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices + into embedding matrix. + outputs: Tensor of shape [batch_size, embedding_size]. Outputs + produced by layer. + approx: str or None. If not None must be "kron". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_approximation, + _EMBEDDING_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + block = self.register_block( + params, block_type(self, vocab_size), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + def register_fully_connected(self, params, inputs, @@ -432,29 +618,31 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_approximation - if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_approximation, + _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias), + reuse=reuse) + block.register_additional_tower(inputs, outputs) - block = self.register_block(params, block_type(self, has_bias), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + self._add_uses(params, 1) def register_conv2d(self, params, @@ -462,44 +650,262 @@ class LayerCollection(object): padding, inputs, outputs, + data_format=None, + dilations=None, approx=None, reuse=VARIABLE_SCOPE): - """Registers a convolutional layer. + """Registers a call to tf.nn.conv2d(). Args: params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [kernel_height, kernel_width, in_channels, out_channels]. Bias should have shape [out_channels]. - strides: 1-D Tensor of length 4. Strides for convolution kernel. + strides: List of 4 ints. Strides for convolution kernel. padding: string. see tf.nn.conv2d for valid values. inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. Output produced by layer. - approx: str. One of "kron" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_conv2d_approximation + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_approximation, + _CONV2D_APPROX_TO_BLOCK_TYPES) + + # It feels bad to pass in configuration that has to do with the internal + # implementation. And then we can`t use the same constructor for both + # anymore and are thus forced to use this ugly if-statement. + # TODO(b/74793309): Clean this up? + if approx == APPROX_KRONECKER_NAME: + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches"), + reuse=reuse) + elif approx == APPROX_DIAGONAL_NAME: + assert strides[0] == strides[-1] == 1 + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilations=dilations, + data_format=data_format), + reuse=reuse) + else: + raise NotImplementedError(approx) - if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_convolution(self, + params, + inputs, + outputs, + padding, + strides=None, + dilation_rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.convolution(). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [..filter_spatial_size.., + in_channels, out_channels]. Bias should have shape [out_channels]. + inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. + Inputs to layer. + outputs: Tensor of shape [batch_size, ..output_spatial_size.., + out_channels]. Output produced by layer. + padding: string. see tf.nn.conv2d for valid values. + strides: List of ints of length len(..input_spatial_size..). Strides for + convolution kernel in spatial dimensions. + dilation_rate: List of ints of length len(..input_spatial_size..). + Dilations along spatial dimension. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_KRONECKER_NAME + + block = self.register_block( + params, + fb.ConvKFCBasicFB( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + dilation_rate=dilation_rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_depthwise_conv2d(self, + params, + inputs, + outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.depthwise_conv2d(). + + Args: + params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Convolutional filter. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_height, output_width, + in_channels * channel_multiplier]. Output produced by depthwise conv2d. + strides: List of ints of length 4. Strides along all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rates in spatial + dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must "diagonal". The Fisher + approximation to use. If None the default value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + # TODO(b/74793309): Have this use _get_block_type like the other + # registration functions? + assert approx is None or approx == APPROX_DIAGONAL_NAME + assert data_format in [None, "NHWC"] - block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block( - params, block_type(self, params, strides, padding), reuse=reuse) - block.register_additional_minibatch(inputs, outputs) + params, + fb.DepthwiseConvDiagonalFB( + layer_collection=self, + params=params, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + + self._add_uses(params, 1) + + def register_separable_conv2d(self, + depthwise_params, + pointwise_params, + inputs, + depthwise_outputs, + pointwise_outputs, + strides, + padding, + rate=None, + data_format=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Register a call to tf.nn.separable_conv2d(). + + Note: This requires access to intermediate outputs between depthwise and + pointwise convolutions. + + Args: + depthwise_params: 4-D Tensor of shape [filter_height, filter_width, + in_channels, channel_multiplier]. Filter for depthwise conv2d. + pointwise_params: 4-D Tensor of shape [1, 1, in_channels * + channel_multiplier, out_channels]. Filter for pointwise conv2d. + inputs: Tensor of shape [batch_size, input_height, input_width, + in_channels]. Inputs to layer. + depthwise_outputs: Tensor of shape [batch_size, output_height, + output_width, in_channels * channel_multiplier]. Output produced by + depthwise conv2d. + pointwise_outputs: Tensor of shape [batch_size, output_height, + output_width, out_channels]. Output produced by pointwise conv2d. + strides: List of ints of length 4. Strides for depthwise conv2d kernel in + all dimensions. + padding: string. see tf.nn.conv2d for valid values. + rate: None or List of ints of length 2. Dilation rate of depthwise conv2d + kernel in spatial dimensions. + data_format: str or None. Format of data. + approx: str or None. If not None must be one of "kron" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + self.register_depthwise_conv2d( + params=depthwise_params, + inputs=inputs, + outputs=depthwise_outputs, + strides=strides, + padding=padding, + rate=rate, + data_format=data_format, + approx=APPROX_DIAGONAL_NAME, + reuse=reuse) + + self.register_conv2d( + params=pointwise_params, + inputs=depthwise_outputs, + outputs=pointwise_outputs, + strides=[1, 1, 1, 1], + padding="VALID", + data_format=data_format, + approx=approx, + reuse=reuse) def register_generic(self, params, @@ -510,32 +916,32 @@ class LayerCollection(object): Args: params: Tensor or tuple of Tensors corresponding to the parameters. - batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of "full" or "diagonal". - reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. + batch_size: 0-D Tensor. Size of the minibatch (for this tower). + approx: str or None. It not None, must be one of "full" or "diagonal". + The Fisher approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `batch_size` to the total + mini-batch size use when estimating the Fisher block for this layer + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. - KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ + block_type, approx = self._get_block_type( + params, approx, self.default_generic_approximation, + _GENERIC_APPROX_TO_BLOCK_TYPES) - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_generic_approximation - - if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - - block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] block = self.register_block(params, block_type(self, params), reuse=reuse) - block.register_additional_minibatch(batch_size) + block.register_additional_tower(batch_size) + + self._add_uses(params, float("inf")) def register_fully_connected_multi(self, params, inputs, outputs, - approx=None): + num_uses=None, approx=None, + reuse=VARIABLE_SCOPE): """Register fully connected layers with shared parameters. This can handle general fully-connected layers with shared parameters, but @@ -546,34 +952,194 @@ class LayerCollection(object): params: Tensor or 2-tuple of Tensors corresponding to weight and bias of this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. - inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs - to layer. In the case of RNNs, one Tensor per time step. - outputs: A list of tensors, the same length as 'inputs', each of shape - [batch_size, output_size]. Outputs produced by layer. In the case of - RNNs, one Tensor per time step. - approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs + to layer. The list indexes each use in the graph (which might + correspond to a "time-step" in an RNN). OR, can be single Tensor, of + shape [num_uses * batch_size , input_size], which is a reshaped version + of a Tensor of shape [num_uses, batch_size, input_size]. + outputs: A list of Tensors, the same length as `inputs`, each of shape + [batch_size, output_size]. Outputs produced by layer. The list indexes + each use in the graph (which might correspond to a "time-step" in an + RNN). Needs to correspond with the order used in `inputs`. OR, can be + a single Tensor of shape [num_uses * batch_size, output_size], which is + a reshaped version of a Tensor of shape [num_uses, batch_size, + output_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None, must be of "kron_indep", "kron_series_1" + or "kron_series_2". The Fisher approximation to use. If None the default + value is used. (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") Raises: - ValueError: For improper value to 'approx'. + ValueError: For improper value to `approx`. """ - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = self.default_fully_connected_multi_approximation - has_bias = isinstance(params, (tuple, list)) + block_type, approx = self._get_block_type( + params, approx, self.default_fully_connected_multi_approximation, + _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) # TODO(b/70283649): something along the lines of find_canonical_output # should be added back in here (and for the other block types, arguably). - if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: - raise ValueError("Bad value {} for approx.".format(approx)) - block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + has_bias = isinstance(params, (tuple, list)) + block = self.register_block(params, block_type(self, has_bias=has_bias, + num_uses=num_uses), + reuse=reuse) + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + def register_conv2d_multi(self, + params, + strides, + padding, + inputs, + outputs, + num_uses=None, + data_format=None, + dilations=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers convolutional layers with shared parameters. - # For now we don't support multiple minibatches for this type of layer, so - # we set reuse=False - self.register_block(params, - block_type(self, inputs, outputs, has_bias=has_bias), - reuse=False) + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: A list of Tensors, each of shape [batch_size, height, width, + in_channels]. Inputs to layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). OR, can be single + Tensor, of shape [num_uses * batch_size, height, width, in_channels], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + height, width, in_channels]. + outputs: A list of Tensors, each of shape [batch_size, height, width, + out_channels]. Output produced by layer. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + Needs to correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, height, width, + out_channels], which is a reshaped version of a Tensor of shape + [num_uses, batch_size, height, width, out_channels]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + data_format: str or None. Format of data. + dilations: List of 4 ints. Dilations along each dimension. + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_conv2d_multi_approximation, + _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) + + block = self.register_block( + params, + block_type( + layer_collection=self, + params=params, + padding=padding, + strides=strides, + data_format=data_format, + dilation_rate=dilations, + extract_patches_fn="extract_image_patches", + num_uses=num_uses), + reuse=reuse) + + block.register_additional_tower(inputs, outputs) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) + + # TODO(b/74108452): change the loss registration functions names to refer + # to "loss functions" instead of distributions. Following naming convention + # of the loss function classes themselves. + + def register_embedding_multi(self, + params, + inputs, + outputs, + num_uses=None, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers embedding layers with shared parameters. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: A list of Tensors, each of shape [batch_size, input_size] and + dtype int32. Indices into embedding matrix. The list indexes each use + in the graph (which might correspond to a "time-step" in an RNN). + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], + which is a reshaped version of a Tensor of shape [num_uses, batch_size, + input_size]. + outputs: A list of Tensors, each of shape [batch_size, embedding_size]. + Outputs produced by layer. The list indexes each use in the graph + (which might correspond to a "time-step" in an RNN). Needs to + correspond with the order used in `inputs`. OR, can be a + single Tensor, of shape [num_uses * batch_size, embedding_size], which + is a reshaped version of a Tensor of shape [num_uses, batch_size, + embedding_size]. + num_uses: int or None. The number uses/time-steps in the graph where the + layer appears. Only needed if both inputs and outputs are given in the + single Tensor format. (Default: None) + approx: str or None. If not None must by "kron_indep". The Fisher + approximation to use. If None the default value is used. + (Default: None) + reuse: bool or str. If True, this adds `inputs` and `outputs` as an + additional mini-batch/tower of data to use when estimating the Fisher + block for this layer (which must have already been registered). If + "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the + word `use` here has a completely different meaning to "use in the graph" + as it perturns to the `inputs`, `outputs`, and `num_uses` arguments.) + (Default: "VARIABLE_SCOPE") + + Raises: + ValueError: For improper value to `approx`. + KeyError: If reuse == True but no FisherBlock found for `params`. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + block_type, approx = self._get_block_type( + params, approx, self.default_embedding_multi_approximation, + _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + vocab_size = int(params.shape[0]) + + block = self.register_block( + params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) + block.register_additional_tower(inputs, outputs) + + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_categorical_predictive_distribution(self, logits, @@ -593,53 +1159,24 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. - If False, create a new FisherBlock. If VARIABLE_SCOPE, use - tf.get_variable_scope().reuse. - - Raises: - ValueError: If reuse == True and name == None. - ValueError: If reuse == True and seed != None. - KeyError: If reuse == True and no existing LossFunction with 'name' found. - KeyError: If reuse == False and existing LossFunction with 'name' found. + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_categorical_predictive_distribution") - - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - if name is None: - raise ValueError( - "If reuse is enabled, loss function's name must be set.") - if seed is not None: - raise ValueError( - "Seed can only be specified at LossFunction instantiation.") - - loss = self._loss_dict.get(name, None) - - if loss is None: - raise KeyError( - "Unable to find loss function named {}. Create a new LossFunction " - "with reuse=False.".format(name)) - - loss.register_additional_minibatch(logits, targets=targets) - else: - if name in self._loss_dict: - raise KeyError( - "Loss function named {} already exists. Set reuse=True to append " - "another minibatch.".format(name)) - loss = lf.CategoricalLogitsNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "categorical_predictive_distribution", + name=name, reuse=reuse) def register_normal_predictive_distribution(self, mean, var=0.5, seed=None, targets=None, - name=None): + name=None, + reuse=VARIABLE_SCOPE): """Registers a normal predictive distribution. Args: @@ -656,21 +1193,23 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `mean` and `var` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_normal_predictive_distribution") - if name in self._loss_dict: - raise NotImplementedError( - "Adding logits to an existing LossFunction not yet supported.") - loss = lf.NormalMeanNegativeLogProbLoss( - mean, var, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, + seed=seed) + self.register_loss_function(loss, mean, + "normal_predictive_distribution", + name=name, reuse=reuse) def register_multi_bernoulli_predictive_distribution(self, logits, seed=None, targets=None, - name=None): + name=None, + reuse=VARIABLE_SCOPE): """Registers a multi-Bernoulli predictive distribution. Args: @@ -683,29 +1222,30 @@ class LayerCollection(object): (Default: None) name: (OPTIONAL) str or None. Unique name for this loss function. If None, a new name is generated. (Default: None) + reuse: bool or str. If True, this adds `logits` as an additional + mini-batch/tower of inputs to the loss-function/predictive distribution + (which must have already been registered). If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") """ - name = name or self._graph.unique_name( - "register_multi_bernoulli_predictive_distribution") - if name in self._loss_dict: - raise NotImplementedError( - "Adding logits to an existing LossFunction not yet supported.") - loss = lf.MultiBernoulliNegativeLogProbLoss( - logits, targets=targets, seed=seed) - self._loss_dict[name] = loss + loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, + seed=seed) + self.register_loss_function(loss, logits, + "multi_bernoulli_predictive_distribution", + name=name, reuse=reuse) def make_or_get_factor(self, cls, args): - """Insert 'cls(args)' into 'self.fisher_factors' if not already present. + """Insert `cls(args)` into 'self.fisher_factors` if not already present. - Wraps constructor in 'tf.variable_scope()' to ensure variables constructed - in 'cls.__init__' are placed under this LayerCollection's scope. + Wraps constructor in `tf.variable_scope()` to ensure variables constructed + in `cls.__init__` are placed under this LayerCollection's scope. Args: cls: Class that implements FisherFactor. - args: Tuple of arguments to pass into 'cls's constructor. Must be + args: Tuple of arguments to pass into `cls's constructor. Must be hashable. Returns: - Instance of 'cls' found in self.fisher_factors. + Instance of `cls` found in self.fisher_factors. """ try: hash(args) @@ -720,3 +1260,10 @@ class LayerCollection(object): with variable_scope.variable_scope(self._var_scope): self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] + + @contextmanager + def as_default(self): + """Sets this LayerCollection as the default.""" + set_default_layer_collection(self) + yield + set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index f8aa230d9ca1f542950f56b1e6cf1ab7ccd3d05f..9f4685380705bd409dbcd7e85d0e3bb4189a6adc 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "get_default_layer_collection", + "set_default_layer_collection", "LayerParametersDict", "LayerCollection", "APPROX_KRONECKER_NAME", diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index cb3e698b9ceab920785adf735f88bd8e535a628f..42d525c2c21f5ba3457cba041261dc3b225dc11e 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -57,30 +57,6 @@ class LossFunction(object): """The inputs to the loss function (excluding the targets).""" pass - @property - def input_minibatches(self): - """A `list` of inputs to the loss function, separated by minibatch. - - Typically there will be one minibatch per tower in a multi-tower setup. - Returns a list consisting of `self.inputs` by default; `LossFunction`s - supporting registering multiple minibatches should override this method. - - Returns: - A `list` of `Tensor`s representing - """ - return [self.inputs] - - @property - def num_registered_minibatches(self): - """Number of minibatches registered for this LossFunction. - - Typically equal to the number of towers in a multi-tower setup. - - Returns: - An `int` representing the number of registered minibatches. - """ - return len(self.input_minibatches) - def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: @@ -474,7 +450,6 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): assert len(variance.shape) == 2, "Expect 2D variance tensor." self._mean = mean self._variance = variance - self._scale = math_ops.sqrt(variance) self._targets = targets super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) @@ -484,7 +459,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def dist(self): - return normal.Normal(loc=self._mean, scale=self._scale) + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) @property def params(self): @@ -502,7 +477,7 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def _fisher_mean_factor(self): - return 1. / self._scale + return 1. / math_ops.sqrt(self._variance) @property def _fisher_var(self): @@ -611,36 +586,13 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, index in [0, output_size). seed: int or None. Default random seed when sampling. """ - self._logits_components = [] - self._targets_components = [] - self.register_additional_minibatch(logits, targets=targets) + self._logits = logits + self._targets = targets super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) - def register_additional_minibatch(self, logits, targets=None): - """Register an additiona minibatch's worth of parameters. - - Args: - logits: Tensor of shape [batch_size, output_size]. Parameters for - underlying distribution. - targets: None or Tensor of shape [batch_size, output_size]. Each row must - be a one-hot vector. - """ - self._logits_components.append(logits) - self._targets_components.append(targets) - - @property - def _logits(self): - return array_ops.concat(self._logits_components, axis=0) - - @property - def input_minibatches(self): - return self._logits_components - @property def targets(self): - if all(target is None for target in self._targets_components): - return None - return array_ops.concat(self._targets_components, axis=0) + return self._targets @property def dist(self): @@ -661,19 +613,19 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def multiply_fisher(self, vector): probs = self._probs return vector * probs - probs * math_ops.reduce_sum( - vector * probs, axis=-1, keep_dims=True) + vector * probs, axis=-1, keepdims=True) def multiply_fisher_factor(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=-1, keep_dims=True) + sqrt_probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_transpose(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=-1, keep_dims=True) + probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py index 705a871d482565897e7ac850327729a6186f1746..4279cb2792854249e3e076d200e2656bc615779d 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -33,7 +33,6 @@ _allowed_symbols = [ "CategoricalLogitsNegativeLogProbLoss", "OnehotCategoricalLogitsNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", - "MultiBernoulliNegativeLogProbLoss", "insert_slice_in_zeros", ] diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 1974b07acfc879dc4bc844db9af88fd1043d6698..f01c5a832212f88d80529672b652ca04d45c0f0e 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,16 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est # pylint enable=long-line +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import gradient_descent @@ -47,8 +51,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): name="KFAC", estimation_mode="gradients", colocate_gradients_with_ops=True, - cov_devices=None, - inv_devices=None): + batch_size=None, + placement_strategy=None, + **kwargs): """Initializes the KFAC optimizer with the given settings. Args: @@ -61,6 +66,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): damping: The damping factor used to stabilize training due to errors in the local approximation with the Fisher information matrix, and to regularize the update direction by making it closer to the gradient. + If damping is adapted during training then this value is used for + initializing damping varaible. (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher @@ -86,12 +93,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. (Default: True) - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. + batch_size: The size of the mini-batch. Only needed when momentum_type + == 'qmodel' or when automatic adjustment is used. (Default: None) + placement_strategy: string, Device placement strategy used when creating + covariance variables, covariance ops, and inverse ops. + (Default: `None`) + **kwargs: Arguments to be passesd to specific placement + strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. Raises: ValueError: If the momentum type is unsupported. @@ -100,20 +108,32 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ - - variables = var_list - if variables is None: - variables = tf_variables.trainable_variables() - - self._fisher_est = est.FisherEstimator( - variables, - cov_ema_decay, - damping, - layer_collection, - estimation_mode=estimation_mode, - colocate_gradients_with_ops=colocate_gradients_with_ops, - cov_devices=cov_devices, - inv_devices=inv_devices) + # Parameters to be passed to the Fisher estimator: + self._variables = var_list or tf_variables.trainable_variables + self._cov_ema_decay = cov_ema_decay + self._layers = layer_collection + self._estimation_mode = estimation_mode + self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # The below paramaters are required only if damping needs to be adapated. + # These parameters can be set by calling + # set_damping_adaptation_params() explicitly. + self._damping_adaptation_decay = 0.95 + self._damping_adaptation_interval = 5 + # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval) + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._adapt_damping = False + self._min_damping = 1e-5 + self._prev_train_batch = None + self._is_chief = False + self._loss_fn = None + self._damping_constant = damping + self._damping = None + self._rho = None + self._prev_loss = None + self._q_model_change = None + self._update_damping_op = None momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] @@ -122,61 +142,224 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): raise ValueError("Unsupported momentum type {}. Must be one of {}." .format(momentum_type, legal_momentum_types)) if momentum_type != "regular" and norm_constraint is not None: - raise ValueError("Update clipping is only supported with momentum" + raise ValueError("Update clipping is only supported with momentum " "type 'regular'.") if momentum_type not in ["regular", "adam"] and momentum != 0: raise ValueError("Momentum must be unspecified if using a momentum_type " "other than 'regular' or 'adam'.") + # Extra parameters of the optimizer self._momentum = momentum self._momentum_type = momentum_type self._norm_constraint = norm_constraint - - # this is a bit of a hack - # TODO(duckworthd): Handle this in a better way (e.g. pass it in?) - self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] - self._losses = layer_collection.losses + self._batch_size = batch_size + self._placement_strategy = placement_strategy + + with variable_scope.variable_scope(name): + self._fisher_est = est.make_fisher_estimator( + placement_strategy=placement_strategy, + variables=self._variables, + cov_ema_decay=self._cov_ema_decay, + damping=self.damping, + layer_collection=self._layers, + exps=(-1,), + estimation_mode=self._estimation_mode, + colocate_gradients_with_ops=self._colocate_gradients_with_ops, + **kwargs) super(KfacOptimizer, self).__init__(learning_rate, name=name) + def set_damping_adaptation_params(self, + is_chief, + prev_train_batch, + loss_fn, + min_damping=1e-5, + damping_adaptation_decay=0.99, + damping_adaptation_interval=5): + """Sets parameters required to adapt damping during training. + + When called, enables damping adaptation according to the Levenberg-Marquardt + style rule described in Section 6.5 of "Optimizing Neural Networks with + Kronecker-factored Approximate Curvature". + + Note that this function creates Tensorflow variables which store a few + scalars and are accessed by the ops which update the damping (as part + of the training op returned by the minimize() method). + + Args: + is_chief: `Boolean`, `True` if the worker is chief. + prev_train_batch: Training data used to minimize loss in the previous + step. This will be used to evaluate loss by calling + `loss_fn(prev_train_batch)`. + loss_fn: `function` that takes as input training data tensor and returns + a scalar loss. + min_damping: `float`(Optional), Minimum value the damping parameter + can take. Default value 1e-5. + damping_adaptation_decay: `float`(Optional), The `damping` parameter is + multipled by the `damping_adaptation_decay` every + `damping_adaptation_interval` number of iterations. Default value 0.99. + damping_adaptation_interval: `int`(Optional), Number of steps in between + updating the `damping` parameter. Default value 5. + + Raises: + ValueError: If `set_damping_adaptation_params` is already called and the + the `adapt_damping` is `True`. + """ + if self._adapt_damping: + raise ValueError("Damping adaptation parameters already set.") + + with variable_scope.variable_scope(self.get_name()): + self._adapt_damping = True + self._is_chief = is_chief + self._prev_train_batch = prev_train_batch + self._loss_fn = loss_fn + self._damping_adaptation_decay = damping_adaptation_decay + self._damping_adaptation_interval = damping_adaptation_interval + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._min_damping = min_damping + + self._rho = variable_scope.get_variable( + "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio. + self._prev_loss = variable_scope.get_variable( + "prev_loss", shape=(), dtype=dtypes.float32, trainable=False) + self._q_model_change = variable_scope.get_variable( + "q_model_change", shape=(), dtype=dtypes.float32, trainable=False) + self._damping = variable_scope.get_variable( + "damping", initializer=self._damping_constant, trainable=False) + + @property + def variables(self): + return self._fisher_est.variables + + @property + def damping(self): + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval + @property def cov_update_thunks(self): - return self._fisher_est.cov_update_thunks + self._maybe_make_and_save_everything() + return self._cov_update_thunks @property def cov_update_ops(self): - return self._fisher_est.cov_update_ops + self._maybe_make_and_save_everything() + return self._cov_update_ops @property def cov_update_op(self): - return self._fisher_est.cov_update_op + self._maybe_make_and_save_everything() + return self._cov_update_op @property def inv_update_thunks(self): - return self._fisher_est.inv_update_thunks + self._maybe_make_and_save_everything() + return self._inv_update_thunks @property def inv_update_ops(self): - return self._fisher_est.inv_update_ops + self._maybe_make_and_save_everything() + return self._inv_update_ops @property def inv_update_op(self): - return self._fisher_est.inv_update_op + self._maybe_make_and_save_everything() + return self._inv_update_op - @property - def variables(self): - return self._fisher_est.variables + def _maybe_make_and_save_everything(self): + if not self._fisher_est.made_vars(): + warnings.warn("These convenience properties will be depcrecated soon. " + "Please use explicit op/thunk creation methods instead " + "(e.g. make_ops_and_vars, etc).", + DeprecationWarning) + (self._cov_update_ops, self._cov_update_op, self._inv_update_ops, + self._inv_update_op, self._cov_update_thunks, + self._inv_update_thunks) = self.make_ops_and_vars() - @property - def damping(self): - return self._fisher_est.damping + def make_ops_and_vars(self): + """Make ops and vars with device placement `self._placement_strategy`. + + See `FisherEstimator.make_ops_and_vars` for details. + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_op: inv_update_ops grouped into a single op. + """ + return self._fisher_est.make_ops_and_vars(scope=self.get_name()) + + def make_vars_and_create_op_thunks(self): + """Make vars and create op thunks. + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) + + def create_ops_and_vars_thunks(self): + """Create thunks that make the ops and vars on demand. + + This function returns 4 lists of thunks: cov_variable_thunks, + cov_update_thunks, inv_variable_thunks, and inv_update_thunks. + + The length of each list is the number of factors and the i-th element of + each list corresponds to the i-th factor (given by the "factors" property). + + Note that the execution of these thunks must happen in a certain + partial order. The i-th element of cov_variable_thunks must execute + before the i-th element of cov_update_thunks (and also the i-th element + of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks + must execute before the i-th element of inv_update_thunks. + + TL;DR (oversimplified): Execute the thunks according to the order that + they are returned. + + Returns: + cov_variable_thunks: A list of thunks that make the cov variables. + cov_update_thunks: A list of thunks that make the cov update ops. + inv_variable_thunks: A list of thunks that make the inv variables. + inv_update_thunks: A list of thunks that make the inv update ops. + """ + scope = self.get_name() + "/" + self._fisher_est.name + return self._fisher_est.create_ops_and_vars_thunks(scope=scope) def minimize(self, *args, **kwargs): - kwargs["var_list"] = kwargs.get("var_list") or self.variables - if set(kwargs["var_list"]) != set(self.variables): - raise ValueError("var_list doesn't match with set of Fisher-estimating " - "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + # Should this variable scope encompass everything below? Or will the super- + # class make another copy of the same name scope? + with variable_scope.variable_scope(self.get_name()): + kwargs["var_list"] = kwargs.get("var_list") or self.variables + if set(kwargs["var_list"]) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + if self._adapt_damping and self._is_chief: + global_step = kwargs.get("global_step", None) + if not global_step: + raise KeyError("global_step needs to be passed to optimizer.minimize " + "if damping parameter is adapted.") + update_damping_op = self._update_damping(self._prev_train_batch, + global_step) + with ops.control_dependencies([update_damping_op]): + loss = args[0] + loss_assign_op = state_ops.assign(self._prev_loss, loss) + train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) + return control_flow_ops.group(loss_assign_op, train_op) + else: + return super(KfacOptimizer, self).minimize(*args, **kwargs) def compute_gradients(self, *args, **kwargs): # args[1] could be our var_list @@ -185,6 +368,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): else: kwargs["var_list"] = kwargs.get("var_list") or self.variables var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") @@ -201,6 +385,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): Returns: An `Operation` that applies the specified gradients. """ + self._maybe_make_and_save_everything() # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. @@ -296,6 +481,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] + def _compute_prev_updates(self, variables): + """Computes previous updates as negative velocities scaled by learning rate. + + Args: + variables: List of variables in the graph that the update will be + applied to. + + Returns: + List of previous updates applied to the `variables`. + """ + return list( + -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name) + for var in variables) + def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, variables): """Compute optimal update hyperparameters from the quadratic model. @@ -336,12 +535,12 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). """ - cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables) + cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses, + variables) # compute the matrix-vector products with the transposed Fisher factor fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) - batch_size = math_ops.cast( self._batch_size, dtype=fft_precon_grads[0].dtype) @@ -374,9 +573,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], [_inner_product_list(grads, prev_updates)]]) - sol = _two_by_two_solve(m, c) - alpha = -sol[0] - mu = -sol[1] + sol = -1. * _two_by_two_solve(m, c) + alpha = sol[0] + mu = sol[1] qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) return alpha, mu, qmodel_change @@ -404,6 +603,52 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return control_flow_ops.cond( math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) + def _assign_q_model_change(self, q_model_change): + """Assigns `q_model_change` to `self._q_model_change` if damping is adapted. + + Note only the chief worker does the assignment. + + Args: + q_model_change: Scalar tensor of type `float32`. + + Returns: + If `adapt_damping` is `True` then returns an assign op, Otherwise returns + a no_op(). + """ + if self._adapt_damping and self._is_chief: + q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change) + else: + q_model_assign_op = control_flow_ops.no_op() + return q_model_assign_op + + def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars, + precon_grads_and_vars): + """Wrapper function for `self._compute_qmodel_hyperparams`. + + Constructs a list of preconditioned gradients and variables. Also creates a + op to asssign the computed q model change to `self._q_model_change`. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradients, variable) + pairs. + + Returns: + (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize + the quadratic model, `q_model_assign_op` assigns the computed q model + change to `self._q_model_change`. + """ + precon_grads = list( + precon_grad for (precon_grad, _) in precon_grads_and_vars) + grads = list(grad for (grad, _) in grads_and_vars) + variables = list(var for (_, var) in grads_and_vars) + prev_updates = self._compute_prev_updates(variables) + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_change = self._compute_qmodel_hyperparams( + precon_grads, prev_updates, grads, variables) + + return alpha, mu, self._assign_q_model_change(q_model_change) + def _compute_update_steps(self, grads_and_vars): """Computes the update steps for the variables given the gradients. @@ -411,8 +656,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): grads_and_vars: List of (gradient, variable) pairs. Returns: - An 'Operation that computes the update steps for the given variables. + A list of tuple (assign_op ,var) where `assign_op` assigns the update + steps to `var`. """ + if self._momentum_type == "regular": # Compute "preconditioned" gradient. precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) @@ -423,8 +670,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): precon_grads_and_vars) # Update the velocity with this and return it as the step. - return self._update_velocities(precon_grads_and_vars, self._momentum) - + if self._adapt_damping and self._is_chief: + _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities(precon_grads_and_vars, self._momentum) + else: + return self._update_velocities(precon_grads_and_vars, self._momentum) elif self._momentum_type == "adam": # Update velocity. velocities_and_vars = self._update_velocities(grads_and_vars, @@ -436,23 +688,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Compute "preconditioned" gradient. precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) - # Extract out singleton lists from the tuple-lists - precon_grads = list( - precon_grad for (precon_grad, _) in precon_grads_and_vars) - grads = list(grad for (grad, _) in grads_and_vars) - variables = list(var for (_, var) in grads_and_vars) - # previous updates are the negative velocities (up to scaling by LR) - prev_updates = list( - -self._zeros_slot(var, "velocity", self._name) for var in variables) - # Compute optimal velocity update parameters according to quadratic model - alpha, mu, _ = self._compute_qmodel_hyperparams( - precon_grads, prev_updates, grads, variables) + alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) - # Update the velocity with precon_grads according to these params - # and return it as the step. - return self._update_velocities( - precon_grads_and_vars, mu, vec_coeff=-alpha) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities( + precon_grads_and_vars, mu, vec_coeff=-alpha) def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): """Updates the velocities of the variables with the given vectors. @@ -482,6 +724,50 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] + def _update_damping(self, prev_batch, global_step): + """Adapts damping parameter. Check KFAC (Section 6.5) for the details. + + The damping parameter is updated according to the Levenberg-Marquardt rule + every `self._damping_adaptation_interval` iterations. + + Args: + prev_batch: Tensor or tuple of tensors which can be passed to + `self._loss_fn` to evaluate loss. + global_step: `Variable` which keeps track of number of times the training + variables have been updated. + Returns: + A `tf.cond` op which updates the damping parameter. + """ + def compute_damping(): + """"Adapts damping parameter based on "reduction ratio". + + Reduction ratio captures how closely the quadratic approximation to the + loss function approximates the actual loss within a trust region. The + damping update tries to make the damping as small as possible while + maintaining the property that the quadratic model remains a good local + approximation to the loss function. + + Returns: + An Op to assign newly computed damping value to `self._damping`. + """ + prev_batch_loss = self._loss_fn(prev_batch) + with ops.control_dependencies([prev_batch_loss]): + rho_assign = self._rho.assign( + (prev_batch_loss - self._prev_loss) / self._q_model_change) + with ops.control_dependencies([rho_assign]): + new_damping = control_flow_ops.case( + [(self._rho < 0.25, lambda: self.damping / self._omega), + (self._rho > 0.75, lambda: self.damping * self._omega)], + lambda: self.damping) + with ops.control_dependencies([new_damping]): + new_damping_min = math_ops.maximum(new_damping, self._min_damping) + return control_flow_ops.group(self._damping.assign(new_damping_min)) + + return control_flow_ops.cond( + math_ops.equal( + math_ops.mod(global_step + 1, self._damping_adaptation_interval), + 0), compute_damping, control_flow_ops.no_op) + def _inner_product_list(list1, list2): return math_ops.add_n( diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py new file mode 100644 index 0000000000000000000000000000000000000000..bf12dbaa9adbaa4af1511034aef0b5ab59d53e26 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements placement strategies for cov and inv ops, cov variables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope + + +def _make_thunk_on_device(func, device): + def thunk(): + with tf_ops.device(device): + return func() + return thunk + + +class RoundRobinPlacementMixin(object): + """Implements round robin placement strategy for ops and variables.""" + + def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + """Initializes the RoundRobinPlacementMixin class. + + Args: + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + *args: + **kwargs: + + """ + super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + self._cov_devices = cov_devices + self._inv_devices = inv_devices + + def make_ops_and_vars(self, scope=None): + """Make ops and vars with a round-robin device placement strategy. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `None` then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all ops will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_ops: List of ops that compute the cov updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + cov_update_op: cov_update_ops grouped into a single op. + inv_update_ops: List of ops that compute the inv updates. Corresponds + one-to-one with the list of factors given by the "factors" property. + inv_update_op: inv_update_ops grouped into a single op. + cov_update_thunks: Thunks that make the ops in cov_update_ops. + inv_update_thunks: Thunks that make the ops in inv_update_ops. + """ + (cov_update_thunks, + inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope) + cov_update_ops = [thunk() for thunk in cov_update_thunks] + inv_update_ops = [thunk() for thunk in inv_update_thunks] + + scope = self.name if scope is None else scope + with variable_scope.variable_scope(scope): + cov_update_op = control_flow_ops.group(cov_update_ops, + name="cov_update_op") + inv_update_op = control_flow_ops.group(inv_update_ops, + name="inv_update_op") + + return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, + cov_update_thunks, inv_update_thunks) + + def make_vars_and_create_op_thunks(self, scope=None): + """Make vars and create op thunks w/ a round-robin device placement strat. + + For each factor, all of that factor's cov variables and their associated + update ops will be placed on a particular device. A new device is chosen + for each factor by cycling through list of devices in the + `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no + explicit device placement occurs. + + An analogous strategy is followed for inverse update ops, with the list of + devices being given by the `self._inv_devices` attribute. + + Inverse variables on the other hand are not placed on any specific device + (they will just use the current the device placement context, whatever + that happens to be). The idea is that the inverse variable belong where + they will be accessed most often, which is the device that actually applies + the preconditioner to the gradient. The user will be responsible for setting + the device context for this. + + Args: + scope: A string or None. If None it will be set to the name of this + estimator (given by the name property). All variables will be created, + and all thunks will execute, inside of a variable scope of the given + name. (Default: None) + + Returns: + cov_update_thunks: List of cov update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + inv_update_thunks: List of inv update thunks. Corresponds one-to-one with + the list of factors given by the "factors" property. + """ + # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. + (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, + inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) + + if self._cov_devices: + cov_update_thunks = [] + for cov_variable_thunk, cov_update_thunk, device in zip( + cov_variable_thunks_raw, cov_update_thunks_raw, + itertools.cycle(self._cov_devices)): + with tf_ops.device(device): + cov_variable_thunk() + cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, + device)) + else: + for cov_variable_thunk in cov_variable_thunks_raw: + cov_variable_thunk() + cov_update_thunks = cov_update_thunks_raw + + for inv_variable_thunk in inv_variable_thunks_raw: + inv_variable_thunk() + + if self._inv_devices: + inv_update_thunks = [] + for inv_update_thunk, device in zip(inv_update_thunks_raw, + itertools.cycle(self._inv_devices)): + inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, + device)) + else: + inv_update_thunks = inv_update_thunks_raw + + return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index e89508fa46b6e2ce278e5373e6c9d17203ad1ef2..b6f42815e79fa5eb9c6a2aa9f99ac3ec5a70ad0a 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -24,11 +24,13 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables @@ -144,7 +146,9 @@ def layer_params_to_mat2d(vector): [-1, w_part.shape.as_list()[-1]]) return array_ops.concat( (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) - else: + elif isinstance(vector, ops.IndexedSlices): + return vector + else: # Tensor or Tensor-like. return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) @@ -163,6 +167,11 @@ def mat2d_to_layer_params(vector_template, mat2d): if isinstance(vector_template, (tuple, list)): w_part, b_part = mat2d[:-1], mat2d[-1] return array_ops.reshape(w_part, vector_template[0].shape), b_part + elif isinstance(vector_template, ops.IndexedSlices): + if not isinstance(mat2d, ops.IndexedSlices): + raise TypeError( + "If vector_template is an IndexedSlices, so should mat2d.") + return mat2d else: return array_ops.reshape(mat2d, vector_template.shape) @@ -234,19 +243,22 @@ class SubGraph(object): # Set of all ancestor Tensors, Ops to 'outputs'. self._members = set() - self._recurse_add(outputs) + self._iter_add(outputs) - def _recurse_add(self, nodes): - """Recursively adds all of nodes' ancestors.""" - for node in nodes: - if node in self._members: - continue - self._members.add(node) + def _iter_add(self, root): + """Iteratively adds all of nodes' ancestors using depth first search.""" + stack = [root] + while stack: + nodes = stack.pop() + for node in nodes: + if node in self._members: + continue + self._members.add(node) - if isinstance(node, ops.Tensor): - self._recurse_add((node.op,)) - elif isinstance(node, ops.Operation): - self._recurse_add(node.inputs) + if isinstance(node, ops.Tensor): + stack.append((node.op,)) + elif isinstance(node, ops.Operation): + stack.append(node.inputs) def is_member(self, node): """Check if 'node' is in this subgraph.""" @@ -420,5 +432,266 @@ def batch_execute(global_step, thunks, batch_size, name=None): return result +def extract_convolution_patches(inputs, + filter_shape, + padding, + strides=None, + dilation_rate=None, + name=None, + data_format=None): + """Extracts inputs to each output coordinate in tf.nn.convolution. + + This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), + where the number of spatial dimensions may be something other than 2. + + Assumes, + - First dimension of inputs is batch_size + - Convolution filter is applied to all input channels. + + Args: + inputs: Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). + filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). + padding: string. Padding method. One of "VALID", "SAME". + strides: None or list of ints. Strides along spatial dimensions. + dilation_rate: None or list of ints. Dilation along spatial dimensions. + name: None or str. Name of Op. + data_format: None or str. Format of data. + + Returns: + Tensor of shape [batch_size, ..spatial_image_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: If data_format does not put channel last. + ValueError: If inputs and filter disagree on in_channels. + """ + if not is_data_format_channel_last(data_format): + raise ValueError("Channel must be last dimension.") + with ops.name_scope(name, "extract_convolution_patches", + [inputs, filter_shape, padding, strides, dilation_rate]): + batch_size = inputs.shape.as_list()[0] + in_channels = inputs.shape.as_list()[-1] + + # filter_shape = spatial_filter_shape + [in_channels, out_channels] + spatial_filter_shape = filter_shape[:-2] + if in_channels != filter_shape[-2]: + raise ValueError("inputs and filter_shape must agree on in_channels.") + + # Map each input feature to a location in the output. + out_channels = np.prod(spatial_filter_shape) * in_channels + filters = linalg_ops.eye(out_channels) + filters = array_ops.reshape( + filters, + list(spatial_filter_shape) + [in_channels, out_channels]) + + result = nn_ops.convolution( + inputs, + filters, + padding=padding, + strides=strides, + dilation_rate=dilation_rate) + spatial_output_shape = result.shape.as_list()[1:-1] + result = array_ops.reshape(result, + [batch_size or -1] + spatial_output_shape + + list(spatial_filter_shape) + [in_channels]) + + return result + + +def extract_pointwise_conv2d_patches(inputs, + filter_shape, + name=None, + data_format=None): + """Extract patches for a 1x1 conv2d. + + Args: + inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. + filter_shape: List of 4 ints. Shape of filter to apply with conv2d() + name: None or str. Name for Op. + data_format: None or str. Format for data. See 'data_format' in + tf.nn.conv2d() for details. + + Returns: + Tensor of shape [batch_size, ..spatial_input_shape.., + ..spatial_filter_shape.., in_channels] + + Raises: + ValueError: if inputs is not 4-D. + ValueError: if filter_shape is not [1, 1, ?, ?] + ValueError: if data_format is not channels-last. + """ + if inputs.shape.ndims != 4: + raise ValueError("inputs must have 4 dims.") + if len(filter_shape) != 4: + raise ValueError("filter_shape must have 4 dims.") + if filter_shape[0] != 1 or filter_shape[1] != 1: + raise ValueError("filter_shape must have shape 1 along spatial dimensions.") + if not is_data_format_channel_last(data_format): + raise ValueError("data_format must be channels last.") + with ops.name_scope(name, "extract_pointwise_conv2d_patches", + [inputs, filter_shape]): + ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. + strides = [1, 1, 1, 1] # Operate on all pixels. + rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. + padding = "VALID" # Doesn't matter. + result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, + padding) + + batch_size, input_height, input_width, in_channels = inputs.shape.as_list() + filter_height, filter_width, in_channels, _ = filter_shape + return array_ops.reshape(result, [ + batch_size, input_height, input_width, filter_height, filter_width, + in_channels + ]) + + +def is_data_format_channel_last(data_format): + """True if data_format puts channel last.""" + if data_format is None: + return True + return data_format.endswith("C") + + +def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is sparse, B is dense. + + Args: + A: tf.IndexedSlices with dense shape [m, n]. + B: tf.Tensor with shape [n, k]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A doesn't represent a matrix. + ValueError: If B is not rank-2. + """ + with ops.name_scope(name, "matmul_sparse_dense", [A, B]): + if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: + raise ValueError("A must represent a matrix. Found: %s." % A) + if B.shape.ndims != 2: + raise ValueError("B must be a matrix.") + new_values = math_ops.matmul(A.values, B) + return ops.IndexedSlices( + new_values, + A.indices, + dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) + + +def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. + + Args: + A_diag: diagonal entries of matrix A of shape [m, m]. + B: tf.IndexedSlices. Represents matrix of shape [m, n]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A_diag is not rank-1. + ValueError: If B doesn't represent a matrix. + """ + with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): + A_diag = ops.convert_to_tensor(A_diag) + if A_diag.shape.ndims != 1: + raise ValueError("A_diag must be a rank-1 Tensor.") + if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: + raise ValueError("B must represent a matrix. Found: %s." % B) + a = array_ops.gather(A_diag, B.indices) + a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) + return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + + +class PartitionedTensor(object): + """A Tensor partitioned across its 0-th dimension.""" + + def __init__(self, tensors): + """Initializes PartitionedTensor. + + Args: + tensors: List of Tensors. All Tensors must agree on shape (excepting + batch dimension) and dtype. + + Raises: + ValueError: If 'tensors' has length zero. + ValueError: if contents of 'tensors' don't agree on shape or dtype. + """ + if not tensors: + raise ValueError("tensors must be a list of 1+ Tensors.") + + dtype = tensors[0].dtype + if not all(tensor.dtype == dtype for tensor in tensors): + raise ValueError("all tensors must have dtype = %s." % dtype) + + shape = tensors[0].shape[1:] + if not all(tensor.shape[1:] == shape for tensor in tensors): + raise ValueError("All tensors must have shape = %s (excluding batch " + "dimension)." % shape) + + self.tensors = tensors + self._concats = {} # {device: Tensor} + + @property + def shape(self): + feature_shape = self.tensors[0].shape[1:] + batch_size = sum([tensor.shape[0] for tensor in self.tensors], + tensor_shape.Dimension(0)) + return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) + + def get_shape(self): + return self.shape + + @property + def dtype(self): + return self.tensors[0].dtype + + def __str__(self): + return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( + self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) + + def __hash__(self): + return hash(tuple(self.tensors)) + + def __eq__(self, other): + if not isinstance(other, PartitionedTensor): + return False + return self.tensors == other.tensors + + def __ne__(self, other): + return not self == other # pylint: disable=g-comparison-negation + + def __getitem__(self, key): + return self.as_tensor()[key] + + def as_tensor(self, dtype=None, name=None, as_ref=False): + with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): + assert not as_ref + assert dtype in [None, self.dtype] + result = array_ops.concat(self.tensors, axis=0) + + # Cache 'result' if we haven't already cached a value for this device. + if result.device not in self._concats: + self._concats[result.device] = result + return self._concats[result.device] + + @property + def device(self): + # PartitionedTensors in general do not live on a single device. If the + # device cannot be determined unambiguously this property will return None. + device = self.tensors[0].device + if all(tensor.device == device for tensor in self.tensors): + return device + return None + + +ops.register_tensor_conversion_function( + PartitionedTensor, + lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index fe8e39c212c2c3381f9aa6fdb9fdf423ff958481..330d222dbf70fcfa02ffd47261c0513d9dd6e0e9 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -40,6 +40,11 @@ _allowed_symbols = [ "fwd_gradients", "ensure_sequence", "batch_execute", + "extract_convolution_patches", + "extract_pointwise_conv2d_patches", + "is_data_format_channel_last", + "matmul_sparse_dense", + "matmul_diag_sparse", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 894e6f6946bb59810a9da2d304cc0dd43d25201d..c8812d4b23f94102d093db878a709b090a3318d6 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -70,6 +70,7 @@ py_test( "python/ops/core_test.py", ], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":_typecheck", ":core", @@ -213,14 +214,3 @@ py_test( "//tensorflow/python:math_ops", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py index abc18aa123bb4d40b54d22ec03257c5350118d13..0c6bba758b429a8c4112bc6abb2fae542b5dfc14 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -361,6 +361,10 @@ class LabeledTensor(object): def dtype(self): return self._tensor.dtype + @property + def shape(self): + return self._tensor.shape + @property def name(self): return self._tensor.name diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py index e70b4923749d89aba1bd0187857d762305daeb07..e378db56afb1d4f9463d2c9b0f1fa4c0feea8fb0 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py @@ -244,6 +244,9 @@ class LabeledTensorTest(test_util.Base): def test_dtype(self): self.assertEqual(self.lt.dtype, self.lt.tensor.dtype) + def test_shape(self): + self.assertEqual(self.lt.shape, self.lt.tensor.shape) + def test_get_shape(self): self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape()) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index c957b41a49b292225e547ce17b0c5a247810325a..3ba1026383ef146adb32197ae41b5c251155bf46 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -951,7 +951,7 @@ def define_reduce_op(op_name, reduce_fn): intermediate_axes.append(axis) reduce_op = reduce_fn( - labeled_tensor.tensor, reduction_dimensions, keep_dims=True) + labeled_tensor.tensor, reduction_dimensions, keepdims=True) reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes) return squeeze(reduce_lt, axes_to_squeeze, name=scope) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py index 0727f4cf88728dc3d919e662d65c93a658ac730b..39e9d65407f3b1e79804317023ea03dd81484ff5 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py @@ -660,7 +660,7 @@ class ReduceSumTest(Base): sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')}) golden_lt = core.LabeledTensor( math_ops.reduce_sum( - self.original_lt.tensor, 1, keep_dims=True), + self.original_lt.tensor, 1, keepdims=True), [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) self.assertLabeledTensorsEqual(sum_lt, golden_lt) @@ -668,7 +668,7 @@ class ReduceSumTest(Base): sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou')) golden_lt = core.LabeledTensor( math_ops.reduce_sum( - self.original_lt.tensor, 1, keep_dims=True), + self.original_lt.tensor, 1, keepdims=True), [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) self.assertLabeledTensorsEqual(sum_lt, golden_lt) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 852d06e1e3cc8f8deecd15b7436cd4e4a393ad66..d5b3b279a1b7327602790c0260349cb0c758aa86 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -188,6 +188,7 @@ py_test( size = "small", srcs = ["python/layers/normalization_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":layers_py", "//tensorflow/contrib/framework:framework_py", @@ -353,6 +354,7 @@ py_test( size = "small", srcs = ["python/ops/sparse_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":layers_py", "//tensorflow/python:array_ops", @@ -390,15 +392,3 @@ py_test( "//tensorflow/python:variables", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 337c9e06b870b2cca53fcdbf3d94225660e193c4..00f03a111ae8be7f49761ef5fb5a82810bcca182 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -104,6 +104,7 @@ See the @{$python/contrib.layers} guide. @@infer_real_valued_columns @@sequence_input_from_feature_columns +@@group_norm @@instance_norm """ @@ -122,6 +123,7 @@ _allowed_symbols = ['bias_add', 'conv3d', 'elu', 'feature_column', + 'group_norm', 'instance_norm', 'legacy_fully_connected', 'legacy_linear', diff --git a/tensorflow/contrib/layers/kernels/BUILD b/tensorflow/contrib/layers/kernels/BUILD index e407a9ce015603094c7bbab72856403e2f0eb1a1..7aae09ff3e9995b2d92b05211b3bf8a94a26ff43 100644 --- a/tensorflow/contrib/layers/kernels/BUILD +++ b/tensorflow/contrib/layers/kernels/BUILD @@ -18,14 +18,3 @@ cc_library( ], alwayslink = 1, ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py index f701647c2b297015f025eb53bd191a1a8c54ec62..28ddaa69a14776e0c157c2e68105ee9e17bc3cbb 100644 --- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py +++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py @@ -200,7 +200,7 @@ class SparseCrossOpTest(test.TestCase): self._assert_sparse_tensor_equals(expected_out, sess.run(op)) def test_large_batch(self): - """Tests with large batch size to force multithreding. + """Tests with large batch size to force multithreading. """ batch_size = 5000 col1 = [] diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index b62e3050cd7003f1ba72061b133ff9b5d6b616da..49c3faf3b7f5eaa3b1542a1fdddcfaff99737a24 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -140,6 +140,9 @@ def safe_embedding_lookup_sparse(embedding_weights, # Prune invalid ids and weights. sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights) + if combiner != "sum": + sparse_ids, sparse_weights = _prune_invalid_weights( + sparse_ids, sparse_weights) # Fill in dummy values for empty features, if necessary. sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids, @@ -188,13 +191,23 @@ def _prune_invalid_ids(sparse_ids, sparse_weights): is_id_valid = math_ops.greater_equal(sparse_ids.values, 0) if sparse_weights is not None: is_id_valid = math_ops.logical_and( - is_id_valid, math_ops.greater(sparse_weights.values, 0)) + is_id_valid, + array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)) sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid) if sparse_weights is not None: sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid) return sparse_ids, sparse_weights +def _prune_invalid_weights(sparse_ids, sparse_weights): + """Prune invalid weights (< 0) from the input ids and weights.""" + if sparse_weights is not None: + is_weights_valid = math_ops.greater(sparse_weights.values, 0) + sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid) + sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid) + return sparse_ids, sparse_weights + + def scattered_embedding_lookup(params, values, dimension, @@ -470,7 +483,7 @@ def embedding_lookup_unique(params, ids, name=None): ids = ops.convert_to_tensor(ids) shape = array_ops.shape(ids) ids_flat = array_ops.reshape( - ids, math_ops.reduce_prod(shape, keep_dims=True)) + ids, math_ops.reduce_prod(shape, keepdims=True)) unique_ids, idx = array_ops.unique(ids_flat) unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids) embeds_flat = array_ops.gather(unique_embeddings, idx) diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py index 89c9d37bd09cb6c43eebb91f3a16600eae9cb490..f42112206d0db9d2e42bd4cff19f6a6533951d46 100644 --- a/tensorflow/contrib/layers/python/layers/encoders.py +++ b/tensorflow/contrib/layers/python/layers/encoders.py @@ -125,7 +125,7 @@ def embed_sequence(ids, `reuse` is `None` or `False`. """ if not (reuse or (vocab_size and embed_dim)): - raise ValueError('Must specify vocab size and embedding dimension when not' + raise ValueError('Must specify vocab size and embedding dimension when not ' 'reusing. Got vocab_size=%s and embed_dim=%s' % ( vocab_size, embed_dim)) with variable_scope.variable_scope( diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index b7d34d6435789e54403926a342481971e854b449..3ae07cedab0be2da8ec633cfd84e07cfdfb11457 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -48,7 +48,7 @@ you should choose depends on (1) the feature type and (2) the model type. recommended. embedded_dept_column = embedding_column( - sparse_column_with_keys("department", ["math", "philosphy", ...]), + sparse_column_with_keys("department", ["math", "philosophy", ...]), dimension=10) * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`). @@ -154,6 +154,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +from tensorflow.python.util import nest # Imports the core `InputLayer` symbol in contrib during development. @@ -554,28 +555,70 @@ def sparse_column_with_integerized_feature(column_name, class _SparseColumnHashed(_SparseColumn): """See `sparse_column_with_hash_bucket`.""" + def __new__(cls, + column_name, + is_integerized=False, + bucket_size=None, + lookup_config=None, + combiner="sum", + dtype=dtypes.string, + hash_keys=None): + if hash_keys is not None: + if not isinstance(hash_keys, list) or not hash_keys: + raise ValueError("hash_keys must be a non-empty list.") + if (any([not isinstance(key_pair, list) for key_pair in hash_keys]) or + any([len(key_pair) != 2 for key_pair in hash_keys]) or + any([not isinstance(key, int) for key in nest.flatten(hash_keys)])): + raise ValueError( + "Each element of hash_keys must be a pair of integers.") + obj = super(_SparseColumnHashed, cls).__new__( + cls, + column_name, + is_integerized=is_integerized, + bucket_size=bucket_size, + lookup_config=lookup_config, + combiner=combiner, + dtype=dtype) + obj.hash_keys = hash_keys + return obj + def _do_transform(self, input_tensor): if self.dtype.is_integer: sparse_values = string_ops.as_string(input_tensor.values) else: sparse_values = input_tensor.values - sparse_id_values = string_ops.string_to_hash_bucket_fast( - sparse_values, self.bucket_size, name="lookup") - return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values, - input_tensor.dense_shape) + if self.hash_keys: + result = [] + for key in self.hash_keys: + sparse_id_values = string_ops.string_to_hash_bucket_strong( + sparse_values, self.bucket_size, key) + result.append( + sparse_tensor_py.SparseTensor(input_tensor.indices, + sparse_id_values, + input_tensor.dense_shape)) + return sparse_ops.sparse_concat(axis=1, sp_inputs=result, name="lookup") + else: + sparse_id_values = string_ops.string_to_hash_bucket_fast( + sparse_values, self.bucket_size, name="lookup") + return sparse_tensor_py.SparseTensor( + input_tensor.indices, sparse_id_values, input_tensor.dense_shape) def sparse_column_with_hash_bucket(column_name, hash_bucket_size, combiner="sum", - dtype=dtypes.string): + dtype=dtypes.string, + hash_keys=None): """Creates a _SparseColumn with hashed bucket configuration. Use this when your sparse features are in string or integer format, but you don't have a vocab file that maps each value to an integer ID. output_id = Hash(input_feature_string) % bucket_size + When hash_keys is set, multiple integer IDs would be created with each key + pair in the `hash_keys`. This is useful to reduce the collision of hashed ids. + Args: column_name: A string defining sparse column name. hash_bucket_size: An int that is > 1. The number of buckets. @@ -588,6 +631,9 @@ def sparse_column_with_hash_bucket(column_name, * "sqrtn": do l2 normalization on features in the column For more information: `tf.embedding_lookup_sparse`. dtype: The type of features. Only string and integer types are supported. + hash_keys: The hash keys to use. It is a list of lists of two uint64s. If + None, simple and fast hashing algorithm is used. Otherwise, multiple + strong hash ids would be produced with each two unit64s in this argument. Returns: A _SparseColumn with hashed bucket configuration @@ -600,7 +646,8 @@ def sparse_column_with_hash_bucket(column_name, column_name, bucket_size=hash_bucket_size, combiner=combiner, - dtype=dtype) + dtype=dtype, + hash_keys=hash_keys) class _SparseColumnKeys(_SparseColumn): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 78affea44cbfb92523063968dbc1be98841854db..06060b99e7e58787994f20f037ffa451abbc7459 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -815,7 +815,7 @@ class _Transformer(object): """ def __init__(self, columns_to_tensors): - """Initializes transfomer. + """Initializes transformer. Args: columns_to_tensors: A mapping from feature columns to tensors. 'string' @@ -908,7 +908,7 @@ def _gather_feature_columns(feature_columns): def _check_forbidden_sequence_columns(feature_columns): - """Recursively cecks `feature_columns` for `_FORBIDDEN_SEQUENCE_COLUMNS`.""" + """Recursively checks `feature_columns` for `_FORBIDDEN_SEQUENCE_COLUMNS`.""" all_feature_columns = _gather_feature_columns(feature_columns) for feature_column in all_feature_columns: if isinstance(feature_column, _FORBIDDEN_SEQUENCE_COLUMNS): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index fc8f153fe3abdc83aca5abfa9a4bb5f5d5531480..1de9ab705655db9863d9c7d2630f24283c83d44d 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -329,6 +329,55 @@ class FeatureColumnTest(test.TestCase): self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights") self.assertEqual(one_hot.length, 3) + def testOneHotColumnWithSparseColumnWithHashKeys(self): + input_values = ["marlo", "unknown", "omar"] + inputs = constant_op.constant(input_values) + hash_keys = [[10, 20], [20, 30]] + hash_column = fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=10, hash_keys=hash_keys) + columns_to_tensors = {} + columns_to_tensors["ids"] = inputs + hash_column.insert_transformed_feature(columns_to_tensors) + self.assertEqual(len(columns_to_tensors), 2) + self.assertTrue(hash_column in columns_to_tensors) + + one_hot_column = fc.one_hot_column(hash_column) + one_hot_output = one_hot_column._to_dnn_input_layer( + columns_to_tensors[hash_column]) + + expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0., + 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.], + [1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]) + with self.test_session() as sess: + one_hot_value = sess.run(one_hot_output) + self.assertTrue(np.array_equal(one_hot_value, expected)) + + def testSparseColumnWithHashKeysWithUnexpectedHashKeys(self): + with self.assertRaisesRegexp(ValueError, + "hash_keys must be a non-empty list."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[]) + + with self.assertRaisesRegexp(ValueError, + "hash_keys must be a non-empty list."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=1) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[1, 2]) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=["key"]) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[[1, 2.0]]) + def testMissingValueInOneHotColumnForWeightedSparseColumn(self): # Github issue 12583 ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 5c1ff9ec267f1bccd9bee44a4b19e7ed3ec24cf0..151fc7a0d734fe8ea4d7872a4051e82d317a500e 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -51,7 +51,6 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import moving_averages -from tensorflow.python.layers.maxout import maxout # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. @@ -933,7 +932,8 @@ def convolution(inputs, variables_collections=None, outputs_collections=None, trainable=True, - scope=None): + scope=None, + conv_dims=None): """Adds an N-D convolution followed by an optional batch_norm layer. It is required that 1 <= N <= 3. @@ -994,6 +994,10 @@ def convolution(inputs, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for `variable_scope`. + conv_dims: Optional convolution dimensionality, when set it would use the + corresponding convolution (e.g. 2 for Conv 2D, 3 for Conv 3D, ..). When + leaved to None it would select the convolution dimensionality based on + the input rank (i.e. Conv ND, with N = input_rank - 2). Returns: A tensor representing the output of the operation. @@ -1016,6 +1020,9 @@ def convolution(inputs, inputs = ops.convert_to_tensor(inputs) input_rank = inputs.get_shape().ndims + if conv_dims is not None and conv_dims + 2 != input_rank: + raise ValueError('Convolution expects input with rank %d, got %d' % + (conv_dims + 2, input_rank)) if input_rank == 3: layer_class = convolutional_layers.Convolution1D elif input_rank == 4: @@ -1062,10 +1069,134 @@ def convolution(inputs, outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def convolution1d(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + return convolution(inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=1) + +convolution1d.__doc__ = convolution.__doc__ -convolution2d = convolution -convolution3d = convolution +@add_arg_scope +def convolution2d(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + return convolution(inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=2) + +convolution2d.__doc__ = convolution.__doc__ +@add_arg_scope +def convolution3d(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + return convolution(inputs, + num_outputs, + kernel_size, + stride, + padding, + data_format, + rate, + activation_fn, + normalizer_fn, + normalizer_params, + weights_initializer, + weights_regularizer, + biases_initializer, + biases_regularizer, + reuse, + variables_collections, + outputs_collections, + trainable, + scope, + conv_dims=3) + +convolution3d.__doc__ = convolution.__doc__ @add_arg_scope def convolution2d_in_plane( @@ -1411,7 +1542,7 @@ def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): Args: tensor: An `int` `Tensor` to be converted to a `Sparse`. eos_token: An integer. - It is part of the target label that signfies the end of a sentence. + It is part of the target label that signifies the end of a sentence. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. """ @@ -1555,7 +1686,7 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): output_collections: Collection to which the outputs will be added. scope: Optional scope for `name_scope`. Returns: - A `Tensor` or `SparseTensor` conataining the same values as `inputs`, but + A `Tensor` or `SparseTensor` containing the same values as `inputs`, but with innermost dimensions flattened to obtain rank `new_rank`. Raises: @@ -2187,8 +2318,10 @@ def layer_norm(inputs, @add_arg_scope -def images_to_sequence(inputs, data_format=DATA_FORMAT_NHWC, - outputs_collections=None, scope=None): +def images_to_sequence(inputs, + data_format=DATA_FORMAT_NHWC, + outputs_collections=None, + scope=None): """Convert a batch of images into a batch of sequences. Args: inputs: a (num_images, height, width, depth) tensor @@ -2694,8 +2827,11 @@ def separable_convolution2d( @add_arg_scope -def sequence_to_images(inputs, height, output_data_format='channels_last', - outputs_collections=None, scope=None): +def sequence_to_images(inputs, + height, + output_data_format='channels_last', + outputs_collections=None, + scope=None): """Convert a batch of sequences into a batch of images. Args: inputs: (num_steps, num_batches, depth) sequence tensor @@ -2743,7 +2879,7 @@ def softmax(logits, scope=None): logits_2d = array_ops.reshape(logits, [-1, num_logits]) predictions = nn.softmax(logits_2d) predictions = array_ops.reshape(predictions, array_ops.shape(logits)) - if context.in_graph_mode(): + if not context.executing_eagerly(): predictions.set_shape(logits.get_shape()) return predictions @@ -2936,6 +3072,53 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None): return math_ops.div(inputs, array_ops.tile(lengths, multiples)) +@add_arg_scope +def maxout(inputs, num_units, axis=-1, scope=None): + """Adds a maxout op from https://arxiv.org/abs/1302.4389 + + "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron + Courville, + Yoshua Bengio + + Usually the operation is performed in the filter/channel dimension. This can + also be + used after fully-connected layers to reduce number of features. + + Arguments: + inputs: Tensor input + num_units: Specifies how many features will remain after maxout + in the `axis` dimension (usually channel). + This must be multiple of number of `axis`. + axis: The dimension where max pooling will be performed. Default is the + last dimension. + scope: Optional scope for variable_scope. + + Returns: + A `Tensor` representing the results of the pooling operation. + + Raises: + ValueError: if num_units is not multiple of number of features. + """ + with variable_scope.variable_scope(scope, 'MaxOut', [inputs]): + inputs = ops.convert_to_tensor(inputs) + shape = inputs.get_shape().as_list() + num_channels = shape[axis] + if num_channels % num_units: + raise ValueError('number of features({}) is not ' + 'a multiple of num_units({})'.format( + num_channels, num_units)) + shape[axis] = -1 + shape += [num_channels // num_units] + + # Dealing with batches with arbitrary sizes + for i in range(len(shape)): + if shape[i] is None: + shape[i] = array_ops.shape(inputs)[i] + outputs = math_ops.reduce_max( + array_ops.reshape(inputs, shape), -1, keepdims=False) + return outputs + + def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): """Project into the Poincare ball with norm <= 1.0 - epsilon. @@ -2994,16 +3177,16 @@ def legacy_fully_connected(x, `activation_fn` is `None`, the result of `y = w * x + b` is returned. - If `x` has shape [\\\(\\text{dim}_0, \\text{dim}_1, ..., \\text{dim}_n\\\)] - with more than 2 dimensions (\\\(n > 1\\\)), then we repeat the matrix + If `x` has shape [\\(\text{dim}_0, \text{dim}_1, ..., \text{dim}_n\\)] + with more than 2 dimensions (\\(n > 1\\)), then we repeat the matrix multiply along the first dimensions. The result r is a tensor of shape - [\\\(\\text{dim}_0, ..., \\text{dim}_{n-1},\\\) `num_output_units`], - where \\\( r_{i_0, ..., i_{n-1}, k} = - \\sum_{0 \\leq j < \\text{dim}_n} x_{i_0, ... i_{n-1}, j} \cdot w_{j, k}\\\). + [\\(\text{dim}_0, ..., \text{dim}_{n-1},\\) `num_output_units`], + where \\( r_{i_0, ..., i_{n-1}, k} = + \sum_{0 \leq j < \text{dim}_n} x_{i_0, ... i_{n-1}, j} \cdot w_{j, k}\\). This is accomplished by reshaping `x` to 2-D - [\\\(\\text{dim}_0 \\cdot ... \\cdot \\text{dim}_{n-1}, \\text{dim}_n\\\)] + [\\(\text{dim}_0 \cdot ... \cdot \text{dim}_{n-1}, \text{dim}_n\\)] before the matrix multiply and afterwards reshaping it to - [\\\(\\text{dim}_0, ..., \\text{dim}_{n-1},\\\) `num_output_units`]. + [\\(\text{dim}_0, ..., \text{dim}_{n-1},\\) `num_output_units`]. This op creates `w` and optionally `b`. Bias (`b`) can be disabled by setting `bias_init` to `None`. diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0f062adbab3ca9acfb89543b69c7c957bbdf5dd8..b01fd5d5c95ac15c76f9dbe7c77f7e76f12149a9 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -310,6 +310,17 @@ class BiasAddTest(test.TestCase): class ConvolutionTest(test.TestCase): + def testInvalidShape(self): + with self.test_session(): + images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1) + with self.assertRaisesRegexp( + ValueError, 'Convolution expects input with rank 5, got 4'): + layers_lib.convolution3d(images_2d, 32, 3) + images_3d = random_ops.random_uniform((5, 6, 7, 9, 3), seed=1) + with self.assertRaisesRegexp( + ValueError, 'Convolution expects input with rank 4, got 5'): + layers_lib.convolution2d(images_3d, 32, 3) + def testInvalidDataFormat(self): height, width = 7, 9 with self.test_session(): @@ -3155,7 +3166,7 @@ class RepeatTests(test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32) output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3]) - self.assertEqual(output.op.name, 'Repeat/convolution_3/Relu') + self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu') self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 32]) def testRepeatWithScope(self): @@ -3749,7 +3760,7 @@ class StackTests(test.TestCase): layers_lib.convolution2d, [10, 20, 30], kernel_size=[3, 3], padding='SAME') - self.assertEqual(output.op.name, 'Stack/convolution_3/Relu') + self.assertEqual(output.op.name, 'Stack/convolution2d_3/Relu') self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 30]) def testStackWithScope(self): @@ -4135,5 +4146,31 @@ class LegacyFullyConnectedTest(test.TestCase): _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax) +class MaxOutTest(test.TestCase): + + def test_simple(self): + inputs = random_ops.random_uniform((64, 10, 36), seed=1) + graph = _layers.maxout(inputs, num_units=3) + self.assertEqual(graph.get_shape().as_list(), [64, 10, 3]) + + def test_fully_connected(self): + inputs = random_ops.random_uniform((64, 50), seed=1) + graph = _layers.fully_connected(inputs, 50) + graph = _layers.maxout(graph, num_units=10) + self.assertEqual(graph.get_shape().as_list(), [64, 10]) + + def test_nchw(self): + inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) + graph = _layers.conv2d(inputs, 10, 3, padding='SAME') + graph = _layers.maxout(graph, num_units=1) + self.assertEqual(graph.get_shape().as_list(), [10, 100, 100, 1]) + + def test_invalid_shape(self): + inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) + graph = _layers.conv2d(inputs, 3, 10) + with self.assertRaisesRegexp(ValueError, 'number of features'): + graph = _layers.maxout(graph, num_units=2) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/layers/python/layers/normalization.py b/tensorflow/contrib/layers/python/layers/normalization.py index e7d4080ff769327cc74b6629a7705ddfa552169b..c807ab0f2e5c8ac3ec2ae1d84a5b36b5f4ba76a4 100644 --- a/tensorflow/contrib/layers/python/layers/normalization.py +++ b/tensorflow/contrib/layers/python/layers/normalization.py @@ -24,11 +24,13 @@ from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope __all__ = [ + 'group_norm', 'instance_norm', ] @@ -158,3 +160,196 @@ def instance_norm(inputs, if activation_fn is not None: outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + + +@add_arg_scope +def group_norm(inputs, + groups=32, + channels_axis=-1, + reduction_axes=(-3, -2), + center=True, + scale=True, + epsilon=1e-6, + activation_fn=None, + param_initializers=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Functional interface for the group normalization layer. + + Reference: https://arxiv.org/abs/1803.08494. + + "Group Normalization", Yuxin Wu, Kaiming He + + Args: + inputs: A Tensor with at least 2 dimensions one which is channels. All + shape dimensions must be fully defined. + groups: Integer. Divide the channels into this number of groups over which + normalization statistics are computed. This number must be commensurate + with the number of channels in `inputs`. + channels_axis: An integer. Specifies index of channels axis which will be + broken into `groups`, each of which whose statistics will be computed + across. Must be mutually exclusive with `reduction_axes`. Preferred usage + is to specify negative integers to be agnostic as to whether a batch + dimension is included. + reduction_axes: Tuple of integers. Specifies dimensions over which + statistics will be accumulated. Must be mutually exclusive with + `channels_axis`. Statistics will not be accumulated across axes not + specified in `reduction_axes` nor `channel_axis`. Preferred usage is to + specify negative integers to be agnostic to whether a batch dimension is + included. + + Some sample usage cases: + NHWC format: channels_axis=-1, reduction_axes=[-3, -2] + NCHW format: channels_axis=-3, reduction_axes=[-2, -1] + + center: If True, add offset of `beta` to normalized tensor. If False, `beta` + is ignored. + scale: If True, multiply by `gamma`. If False, `gamma` is + not used. When the next layer is linear (also e.g. `nn.relu`), this can be + disabled since the scaling can be done by the next layer. + epsilon: Small float added to variance to avoid dividing by zero. + activation_fn: Activation function, default set to None to skip it and + maintain a linear activation. + param_initializers: Optional initializers for beta, gamma, moving mean and + moving variance. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional collections for the variables. + outputs_collections: Collections to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + scope: Optional scope for `variable_scope`. + + Returns: + A `Tensor` representing the output of the operation. + + Raises: + ValueError: If the rank of `inputs` is undefined. + ValueError: If rank or channels dimension of `inputs` is undefined. + ValueError: If number of groups is not commensurate with number of channels. + ValueError: If reduction_axes or channels_axis are out of bounds. + ValueError: If reduction_axes are not mutually exclusive with channels_axis. + """ + # TODO(shlens): Support partially defined shapes for the inputs. + inputs = ops.convert_to_tensor(inputs) + original_shape = inputs.shape + + if inputs.shape.ndims is None: + raise ValueError('Inputs %s has undefined rank.' % inputs.name) + if channels_axis > (inputs.shape.ndims - 1): + raise ValueError('Axis is out of bounds.') + + # Standardize the channels_axis to be positive and identify # of channels. + if channels_axis < 0: + channels_axis = inputs.shape.ndims + channels_axis + channels = inputs.shape[channels_axis].value + + if channels is None: + raise ValueError('Inputs %s has undefined channel dimension: %d.' % ( + inputs.name, channels_axis)) + + # Standardize the reduction_axes to be positive. + reduction_axes = list(reduction_axes) + for i in range(len(reduction_axes)): + if reduction_axes[i] < 0: + reduction_axes[i] += inputs.shape.ndims + + for a in reduction_axes: + if a > inputs.shape.ndims: + raise ValueError('Axis is out of bounds.') + if inputs.shape[a].value is None: + raise ValueError('Inputs %s has undefined dimensions %d.' % ( + inputs.name, a)) + if channels_axis == a: + raise ValueError('reduction_axis must be mutually exclusive ' + 'with channels_axis') + if groups > channels: + raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) + if channels % groups != 0: + raise ValueError('%d channels is not commensurate with %d groups.' % + (channels, groups)) + + # Determine axes before channels. Some examples of common image formats: + # 'NCHW': before = [N], after = [HW] + # 'NHWC': before = [NHW], after = [] + axes_before_channels = inputs.shape.as_list()[:channels_axis] + axes_after_channels = inputs.shape.as_list()[channels_axis+1:] + + # Manually broadcast the parameters to conform to the number of groups. + params_shape_broadcast = ([1] * len(axes_before_channels) + + [groups, channels // groups] + + [1] * len(axes_after_channels)) + + # Reshape the input by the group within the channel dimension. + inputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + inputs = array_ops.reshape(inputs, inputs_shape) + + # Determine the dimensions across which moments are calculated. + moments_axes = [channels_axis + 1] + for a in reduction_axes: + if a > channels_axis: + moments_axes.append(a + 1) + else: + moments_axes.append(a) + + with variable_scope.variable_scope( + scope, 'GroupNorm', [inputs], reuse=reuse) as sc: + # Note that the params_shape is the number of channels always. + params_shape = [channels] + + # Allocate parameters for the beta and gamma of the normalization. + beta, gamma = None, None + dtype = inputs.dtype.base_dtype + if param_initializers is None: + param_initializers = {} + if center: + beta_collections = utils.get_variable_collections( + variables_collections, 'beta') + beta_initializer = param_initializers.get( + 'beta', init_ops.zeros_initializer()) + beta = variables.model_variable('beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) + beta = array_ops.reshape(beta, params_shape_broadcast) + + if scale: + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') + gamma_initializer = param_initializers.get( + 'gamma', init_ops.ones_initializer()) + gamma = variables.model_variable('gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) + gamma = array_ops.reshape(gamma, params_shape_broadcast) + + # Calculate the moments. + mean, variance = nn.moments(inputs, moments_axes, keep_dims=True) + + # Compute normalization. + # TODO(shlens): Fix nn.batch_normalization to handle the 5-D Tensor + # appropriately so that this operation may be faster. + gain = math_ops.rsqrt(variance + epsilon) + offset = -mean * gain + if gamma is not None: + gain *= gamma + offset *= gamma + if beta is not None: + offset += beta + outputs = inputs * gain + offset + + # Collapse the groups into the channel dimension. + outputs = array_ops.reshape(outputs, original_shape) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py index 5cff1bf0ebb2fe8bc6933de882ecd47a9edf0f94..b6e96350db92baf4770683273be7e5dde73dbcec 100644 --- a/tensorflow/contrib/layers/python/layers/normalization_test.py +++ b/tensorflow/contrib/layers/python/layers/normalization_test.py @@ -166,5 +166,231 @@ class InstanceNormTest(test.TestCase): def testOutputBigInput5DNCHW(self): self.doOutputTest((1, 100, 100, 1, 1), 'NCHW', tol=1e-3) + +class GroupNormTest(test.TestCase): + + def testInvalidGroupSize(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 2, 10, 10)) + with self.assertRaisesRegexp(ValueError, + 'Invalid groups 10 for 2 channels.'): + normalization.group_norm(inputs, groups=10, + reduction_axes=[-2, -1], channels_axis=-3) + + def testBadCommensurateGroup(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, 4, 10, 10)) + with self.assertRaisesRegexp(ValueError, + '4 channels is not commensurate with ' + '3 groups.'): + normalization.group_norm(inputs, groups=3, + reduction_axes=[-2, -1], channels_axis=-3) + + def testAxisIsBad(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 2, 4, 5)) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, channels_axis=5) + with self.assertRaisesRegexp(ValueError, + 'Axis is out of bounds.'): + normalization.group_norm(inputs, reduction_axes=[1, 5]) + + def testNotMutuallyExclusiveAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(10, 32, 32, 32)) + # Specify axis with negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[-2]) + # Specify axis with positive values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=1, reduction_axes=[1, 3]) + # Specify axis with mixed positive and negative values. + with self.assertRaisesRegexp(ValueError, 'mutually exclusive'): + normalization.group_norm(inputs, channels_axis=-2, reduction_axes=[2]) + + def testUnknownShape(self): + inputs = array_ops.placeholder(dtypes.float32) + with self.assertRaisesRegexp(ValueError, 'undefined rank'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedReductionAxes(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 32, None, 4)) + with self.assertRaisesRegexp(ValueError, 'undefined dimensions'): + normalization.group_norm(inputs) + + def testParamsShapeNotFullyDefinedChannelsAxis(self): + inputs = array_ops.placeholder(dtypes.float32, shape=(1, 3, 4, None)) + with self.assertRaisesRegexp(ValueError, 'undefined channel dimension'): + normalization.group_norm(inputs, channels_axis=-1, + reduction_axes=[-3, -2]) + + def testCreateOp(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, height, width, 2*groups), seed=1) + output = normalization.group_norm(images, groups=groups, channels_axis=-1, + reduction_axes=[-3, -2]) + print('name: ', output.op.name) + self.assertListEqual([5, height, width, 2*groups], output.shape.as_list()) + + def testCreateOpFloat64(self): + height, width, groups = 3, 3, 5 + images = random_ops.random_uniform( + (5, height, width, 4*groups), dtype=dtypes.float64, seed=1) + output = normalization.group_norm(images, groups=groups) + self.assertEqual(dtypes.float64, output.dtype) + self.assertListEqual([5, height, width, 4*groups], output.shape.as_list()) + + def testCreateOpNoScaleCenter(self): + height, width, groups = 3, 3, 7 + images = random_ops.random_uniform( + (5, height, width, 3*groups), dtype=dtypes.float32, seed=1) + output = normalization.group_norm(images, groups=groups, center=False, + scale=False) + self.assertListEqual([5, height, width, 3*groups], output.shape.as_list()) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('beta'))) + self.assertEqual(0, len(contrib_variables.get_variables_by_name('gamma'))) + + def testCreateVariables_NHWC(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 8), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-1, reduction_axes=(-3, -2), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testCreateVariables_NCHW(self): + height, width, groups = 3, 3, 4 + images = random_ops.random_uniform((5, 2*groups, height, width), seed=1) + normalization.group_norm(images, groups=4, + channels_axis=-3, reduction_axes=(-2, -1), + center=True, scale=True) + beta = contrib_variables.get_variables_by_name('beta')[0] + gamma = contrib_variables.get_variables_by_name('gamma')[0] + self.assertEqual('GroupNorm/beta', beta.op.name) + self.assertEqual('GroupNorm/gamma', gamma.op.name) + + def testReuseVariables(self): + height, width = 3, 3 + images = random_ops.random_uniform((5, height, width, 4), seed=1) + normalization.group_norm(images, groups=2, scale=True, scope='IN') + normalization.group_norm(images, groups=2, scale=True, scope='IN', + reuse=True) + beta = contrib_variables.get_variables_by_name('beta') + gamma = contrib_variables.get_variables_by_name('gamma') + self.assertEqual(1, len(beta)) + self.assertEqual(1, len(gamma)) + + def testValueCorrectWithReuseVars(self): + height, width = 3, 3 + image_shape = (10, height, width, 4) + images = random_ops.random_uniform(image_shape, seed=1) + output_train = normalization.group_norm(images, groups=2, scope='IN') + output_eval = normalization.group_norm(images, groups=2, scope='IN', + reuse=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + # output_train and output_eval should be the same. + train_np, eval_np = sess.run([output_train, output_eval]) + self.assertAllClose(train_np, eval_np) + + def doOutputTest(self, input_shape, channels_axis=None, reduction_axes=None, + groups=2, tol=1e-2): + # Select the axis for the channel and the dimensions along which statistics + # are accumulated. + if channels_axis < 0: + channels_axis += len(input_shape) + reduced_axes = [channels_axis + 1] + for a in reduction_axes: + if a < 0: + a += len(input_shape) + if a < channels_axis: + reduced_axes.append(a) + else: + reduced_axes.append(a+1) + reduced_axes = tuple(reduced_axes) + + # Calculate the final shape for the output Tensor. + axes_before_channels = input_shape[:channels_axis] + axes_after_channels = input_shape[channels_axis+1:] + channels = input_shape[channels_axis] + outputs_shape = (axes_before_channels + [groups, channels // groups] + + axes_after_channels) + + # Calculate the final shape for the output statistics. + reduced_shape = [] + for i, a in enumerate(outputs_shape): + if i not in reduced_axes: + reduced_shape.append(a) + + for mu in (0.0, 1e2): + for sigma in (1.0, 0.1): + # Determine shape of Tensor after normalization. + expected_mean = np.zeros(reduced_shape) + expected_var = np.ones(reduced_shape) + + inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu + output_op = normalization.group_norm( + inputs, groups=groups, center=False, scale=False, + channels_axis=channels_axis, + reduction_axes=reduction_axes) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outputs = sess.run(output_op) + # Make sure that there are no NaNs + self.assertFalse(np.isnan(outputs).any()) + + outputs = np.reshape(outputs, outputs_shape) + mean = np.mean(outputs, axis=reduced_axes) + var = np.var(outputs, axis=reduced_axes) + # The mean and variance of each example should be close to 0 and 1 + # respectively. + self.assertAllClose(expected_mean, mean, rtol=tol, atol=tol) + self.assertAllClose(expected_var, var, rtol=tol, atol=tol) + + def testOutputSmallInput4D_NHWC(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=3, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput3D_NHWC(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=2, reduction_axes=[0, 1]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-1, reduction_axes=[-3, -2]) + + def testOutputSmallInput4D_NCHW(self): + input_shape = [10, 10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=1, reduction_axes=[2, 3]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputSmallInput3D_NCHW(self): + input_shape = [10, 10, 30] + # Specify axes with positive values. + self.doOutputTest(input_shape, channels_axis=0, reduction_axes=[1, 2]) + # Specify axes with negative values. + self.doOutputTest(input_shape, channels_axis=-3, reduction_axes=[-2, -1]) + + def testOutputBigInput4D_NHWC(self): + self.doOutputTest([5, 100, 100, 1], channels_axis=3, reduction_axes=[1, 2], + groups=1) + + def testOutputBigInput4D_NCHW(self): + self.doOutputTest([1, 100, 100, 4], channels_axis=1, reduction_axes=[2, 3], + groups=4) + + def testOutputSmallInput2D_NC(self): + self.doOutputTest([10, 7*100], channels_axis=1, reduction_axes=[], groups=7) + + def testOutputSmallInput5D_NCXXX(self): + self.doOutputTest([10, 10, 20, 40, 5], + channels_axis=1, + reduction_axes=[2, 3, 4], + groups=5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index cdceea6fee5bdb5aeb6537ea55d25ccf107def4c..69d927e1b3001d14dd1af2f890b07c1a57ab2cfc 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -41,7 +41,7 @@ OPTIMIZER_CLS_NAMES = { "Adagrad": train.AdagradOptimizer, "Adam": train.AdamOptimizer, "Ftrl": train.FtrlOptimizer, - "Momentum": lambda lr: train.MomentumOptimizer(lr, momentum=0.9), + "Momentum": lambda learning_rate: train.MomentumOptimizer(learning_rate, momentum=0.9), # pylint: disable=line-too-long "RMSProp": train.RMSPropOptimizer, "SGD": train.GradientDescentOptimizer, } diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 1ea25bd1a5685eb6f840e621b5739029a660aa0f..a4461a20e54c289886f1a1beb255de12fc054afe 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -61,7 +61,8 @@ class OptimizersTest(test.TestCase): optimizers = [ "SGD", gradient_descent.GradientDescentOptimizer, gradient_descent.GradientDescentOptimizer(learning_rate=0.1), - lambda lr: gradient_descent.GradientDescentOptimizer(learning_rate=lr) + lambda lr: gradient_descent.GradientDescentOptimizer(learning_rate=lr), + "Momentum" ] for optimizer in optimizers: with ops.Graph().as_default() as g: diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 123275e1fde047cd3772528641b2e3b09742fbdc..c4fa3392ef9d5c5dd3760e55f37ed565580918d4 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -29,14 +29,17 @@ from __future__ import print_function import functools import re +import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -46,6 +49,7 @@ from tensorflow.python.util import nest __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") +_USE_DEFAULT = "__rev_block_lib_default" def _acc_grads(*lists_of_grads): @@ -219,7 +223,13 @@ class RevBlock(base.Layer): def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): """Custom gradient fn for a block of reversible residual layers.""" + # Inputs have passed through an Identity. Recover the original Tensors to + # be able to match up side inputs. + assert [u"Identity"] == list(set([x.op.type for x in inputs])) + inputs = [x.op.inputs[0] for x in inputs] side_inputs = inputs[2:] + del inputs + f_side_idxs = [None] * len(self.f_side_input) g_side_idxs = [None] * len(self.g_side_input) assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) @@ -237,9 +247,7 @@ class RevBlock(base.Layer): f_vars_idxs = [[] for _ in range(self.num_layers)] g_vars_idxs = [[] for _ in range(self.num_layers)] - for i, t in enumerate(variables): - ref = _underlying_variable_ref(t) - + for i, ref in enumerate(variables): # Use the name to identify the layer number and function (f or g) regex = LAYER_RE.match(ref.name) layer_no = int(regex.group(1)) @@ -405,12 +413,36 @@ def rev_block(x1, return block.forward(x1, x2) -def recompute_grad(fn): +def enable_with_args(dec): + """A decorator for decorators to enable their usage with or without args.""" + + @functools.wraps(dec) + def new_dec(*args, **kwargs): + if len(args) == 1 and not kwargs and callable(args[0]): + # Used as decorator without args + fn = args[0] + return dec(fn) + else: + return lambda fn: dec(fn, *args, **kwargs) + + return new_dec + + +@enable_with_args +def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. + use_data_dep: `bool`, if `True` will use a dummy data dependency to force + the recompute to happen. If `False` will use a control dependency. By + default will be `True` if in an XLA context and `False` otherwise. XLA + ignores control dependencies and so this data dependency is necessary. + tupleize_grads: `bool`, if `True` will use control dependencies to ensure + that all gradients are produced before any are consumed by downstream ops. + If `use_data_dep` is also `True`, will use a data dependency instead of + a control dependency. Returns: A wrapped fn that is identical to fn when called, but its activations will @@ -420,13 +452,25 @@ def recompute_grad(fn): @functools.wraps(fn) def wrapped(*args): - return _recompute_grad(fn, args) + return _recompute_grad( + fn, args, use_data_dep=use_data_dep, tupleize_grads=tupleize_grads) return wrapped -def _recompute_grad(fn, args): +def _is_on_tpu(): + ctxt = framework_ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access + return control_flow_util.GetContainingXLAContext(ctxt) is not None + + +def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + for arg in args: + if not isinstance(arg, framework_ops.Tensor): + raise ValueError("All inputs to function must be Tensors") + use_data_dep_ = use_data_dep + if use_data_dep_ == _USE_DEFAULT: + use_data_dep_ = _is_on_tpu() cached_vs = [] cached_arg_scope = [] @@ -436,6 +480,8 @@ def _recompute_grad(fn, args): del outputs # Recompute outputs with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) with contrib_framework_ops.arg_scope(cached_arg_scope[0]): with variable_scope.variable_scope(cached_vs[0], reuse=True): outputs = fn(*inputs) @@ -444,6 +490,13 @@ def _recompute_grad(fn, args): outputs = [outputs] outputs = list(outputs) grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):] return grad_inputs, grad_vars @@ -451,11 +504,7 @@ def _recompute_grad(fn, args): @_fn_with_custom_grad(grad_fn) def fn_with_recompute(*args): cached_vs.append(variable_scope.get_variable_scope()) - # TODO(rsepassi): Rm conditional in TF 1.4 - if hasattr(contrib_framework_ops, "current_arg_scope"): - cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) - else: - cached_arg_scope.append({}) + cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) return fn(*args) return fn_with_recompute(*args) @@ -532,7 +581,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): get_vars_fn = ( vs.global_variables if use_global_vars else vs.trainable_variables) len_before_vars = len(get_vars_fn()) - inputs = list(inputs) + inputs = [array_ops.identity(x) for x in inputs] outputs = fn(*inputs) train_vars = get_vars_fn()[len_before_vars:] @@ -549,6 +598,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): """Custom grad fn applying grad_fn for identity Defun.""" fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( defun_inputs, list(op.inputs)) + fn_vars = [_underlying_variable_ref(v) for v in fn_vars] dys = list(dys) assert len(fn_outputs) == len(outputs) assert len(fn_outputs) == len(dys) @@ -581,3 +631,48 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): flat_inputs = nest.flatten(defun_inputs) id_out = identity(*flat_inputs) return id_out + + +def _force_data_dependency(first_compute, then_compute): + """Force all of `then_compute` to depend on all of `first_compute`. + + Uses a dummy data dependency, which is useful when running on TPUs because + XLA ignores control dependencies. Only supports float arguments. + + Args: + first_compute: `list`. These will be made to run before the + `Tensor`s `then_compute`. + then_compute: `list`. These will run after all the `Tensor`s in + `first_compute`. + + Returns: + `list`, same length as `then_compute`. + + Raises: + ValueError: if ranks are unknown or types are not floating. + """ + + def _first_element(x): + if x.get_shape().ndims is None: + raise ValueError("Rank of Tensor %s must be known" % x) + ndims = x.get_shape().ndims + begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32) + size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32) + return array_ops.reshape(array_ops.slice(x, begin, size), []) + + first_compute_sum = math_ops.add_n( + [_first_element(x) for x in first_compute if x is not None]) + dtype = first_compute_sum.dtype + if not dtype.is_floating: + raise ValueError("_force_data_dependency only supports floating dtypes.") + epsilon = np.finfo(dtype.as_numpy_dtype).tiny + zero = array_ops.stop_gradient(epsilon * first_compute_sum) + + return [ + array_ops.identity(x) + zero if x is not None else None + for x in then_compute + ] + + +def _tuple_with_data_dep(tensors): + return _force_data_dependency(tensors, tensors) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index cbcbcd75114a522b95631e4e7e95c1641b0a9987..8c118402a4c85d4b0504754fcd0436ce8b00862d 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -60,8 +60,8 @@ class RevBlockTest(test.TestCase): sess.run(variables.global_variables_initializer()) x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) - self.assertAllClose(x1, x1_inv) - self.assertAllClose(x2, x2_inv) + self.assertAllClose(x1, x1_inv, atol=1e-5) + self.assertAllClose(x2, x2_inv, atol=1e-5) def testBackwardForward(self): @@ -154,7 +154,7 @@ class RevBlockTest(test.TestCase): y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) self.assertAllClose(y_val, yd_val) for g1, g2 in zip(gd_val, g_val): - self.assertAllClose(g1, g2) + self.assertAllClose(g1, g2, rtol=1e-5) def testRevBlock(self): self._testRevBlock() @@ -255,25 +255,68 @@ class RecomputeTest(test.TestCase): def fn_recompute(x): return fn(x) + @rev_block_lib.recompute_grad(use_data_dep=True) + def fn_use_data_dep(x): + return fn(x) + + @rev_block_lib.recompute_grad(tupleize_grads=True) + def fn_tupleize(x): + return fn(x) + + @rev_block_lib.recompute_grad(use_data_dep=True, tupleize_grads=True) + def fn_both(x): + return fn(x) + x = random_ops.random_uniform((3, 1, 3)) - recompute_vars = None - with variable_scope.variable_scope("recompute") as vs: - out1 = math_ops.reduce_sum(fn_recompute(x)) - recompute_vars = vs.trainable_variables() - reg_vars = None - with variable_scope.variable_scope("regular") as vs: - out2 = math_ops.reduce_sum(fn(x)) - reg_vars = vs.trainable_variables() - - grad1 = gradients_impl.gradients(out1, recompute_vars) - grad2 = gradients_impl.gradients(out2, reg_vars) + + names_and_fns = [ + ("recompute", fn_recompute), + ("regular", fn), + ("use_data_dep", fn_use_data_dep), + ("tupleize", fn_tupleize), + ("tuple_and_data_dep", fn_both), + ] + outputs_and_vars = [] + for name, wrapped_fn in names_and_fns: + with variable_scope.variable_scope(name) as vs: + out = math_ops.reduce_sum(wrapped_fn(x)) + outputs_and_vars.append((out, vs.trainable_variables())) + + all_grads = [] + for out, scope_vars in outputs_and_vars: + all_grads.append(gradients_impl.gradients(out, scope_vars)) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) - outs = sess.run([out1, out2, grad1, grad2]) - self.assertAllClose(outs[0], outs[1]) - for g1, g2 in zip(outs[2], outs[3]): - self.assertAllClose(g1, g2) + outputs = list(zip(*outputs_and_vars))[0] + outs, all_grads_val = sess.run([outputs, all_grads]) + + # All outputs are the same + current = outs[0] + for out in outs[1:]: + self.assertAllClose(current, out) + current = out + + # All gradients are the same + for grads in zip(all_grads_val): + current = grads[0] + for g in grads[1:]: + self.assertAllClose(current, g) + current = g + + def testResourceVariable(self): + @rev_block_lib.recompute_grad(tupleize_grads=True) + def layer_with_recompute(inputs): + var = variable_scope.get_variable("var", ()) + return var * inputs + + inputs = array_ops.ones((), dtypes.float32) + with variable_scope.variable_scope("layer", use_resource=True): + outputs = layer_with_recompute(inputs) + loss = math_ops.square(outputs) + grads = gradients_impl.gradients(loss, variables.trainable_variables()) + self.assertEqual(1, len(grads)) + self.assertTrue(grads[0] is not None) class FnWithCustomGradTest(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py index 3409860add8f8c393ffd342633e7023931867dd9..645dc1291eb6370a5e504306fc00a5454dde77ed 100644 --- a/tensorflow/contrib/layers/python/layers/utils_test.py +++ b/tensorflow/contrib/layers/python/layers/utils_test.py @@ -294,7 +294,6 @@ class NPositiveIntegersTest(test.TestCase): self.assertEqual(utils.n_positive_integers(2, 2), (2, 2)) self.assertEqual(utils.n_positive_integers(2, (2, 3)), (2, 3)) self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) - self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) self.assertEqual( utils.n_positive_integers(3, tensor_shape.TensorShape([2, 3, 1])), (2, 3, 1)) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index abf6e393bb0fbbce4e43f6d209e9b30517df36c3..d665fc9335cf22cdfa1e7330ab67003042502515 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -5,6 +5,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load("//tensorflow:tensorflow.bzl", "py_test") + package(default_visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", "//tensorflow:internal", @@ -115,6 +117,7 @@ py_test( size = "small", srcs = ["python/learn/learn_io/data_feeder_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/python:client_testlib", @@ -170,6 +173,7 @@ tf_py_test( "//tensorflow/python:variables", "//tensorflow/python/estimator", ], + tags = ["no_windows"], # TODO: needs investigation on Windows ) py_test( @@ -188,6 +192,7 @@ py_test( size = "small", srcs = ["python/learn/graph_actions_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -224,6 +229,7 @@ py_test( size = "small", srcs = ["python/learn/monitors_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip_gpu"], # b/74437598 deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -426,6 +432,10 @@ py_test( size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], srcs_version = "PY2AND3", + tags = [ + "noasan", # b/73741358 + "nomac", + ], deps = [ ":learn", "//tensorflow/python:array_ops", @@ -584,6 +594,7 @@ py_test( size = "small", srcs = ["python/learn/learn_io/io_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/learn/python/learn/datasets", @@ -813,6 +824,7 @@ py_test( size = "small", srcs = ["python/learn/utils/saved_model_export_utils_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], # TODO: needs investigation on Windows deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -867,15 +879,3 @@ py_binary( "//tensorflow/python:platform", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d516bffc5e0327a3400068b35de5503e5a925a54 --- /dev/null +++ b/tensorflow/contrib/learn/README.md @@ -0,0 +1,143 @@ +EVERYTHING IN THIS DIRECTORY IS DEPRECATED. + +Using functions or classes will result in warnings. + +Instructions for converting to current alternatives are included in the +warnings. A high-level overview is below. + +## Canned Estimators + +Many canned estimators (subclasses of `Estimator`) have equivalents in core: +`DNNClassifier`, `DNNRegressor`, `DNNEstimator`, `LinearClassifier`, +`LinearRegressor`, `DNNLinearCombinedClassifier` and +`DNNLinearCombinedRegressor`. They are exposed under `tf.estimator`. +`DNNEstimator`, `LinearEstimator` and `DNNLinearCombinedEstimator` +are exposed under `tf.contrib.estimator`. + +To migrate to the new api, users need to take the following steps: + +* Replace `tf.contrib.learn` with `tf.estimator`. +* If you subclass any of the estimators, stop doing that. You should be able to + write a factory method that returns a canned estimator instead. If this is not + possible (if you override methods from the canned estimator), consider writing + a custom estimator instead. See `tf.estimator.Estimator`. +* Set `loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE` to preserve loss + reduction as the average over batch. +* Some optimizer-related arguments are no longer passed in the estimator + constructor. Instead, we provide methods that perform the same job by wrapping + an optimizer. Specifically: + * `gradient_clip_norm`: Use `tf.contrib.estimator.clip_gradients_by_norm` + * `embedding_lr_multipliers`: Not supported. + Other arguments: + * `input_layer_min_slice_size`: Replaced by `input_layer_partitioner` + * `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. + * `feature_engineering_fn`: Not supported. You can call your + `feature_engineering_fn` inside your input_fn: + ```python + def new_input_fn(): + features, labels = old_input_fn() + return feature_engineering_fn(features, labels) + ``` +* Use `tf.reshape` to reshape labels in your `input_fn`. `tf.estimator` + classifiers and regressors expect labels as a 2D Tensor of shape + `[batch_size, 1]`, or `[batch_size, n_labels]`. In contrast, + `tf.contrib.learn` classifiers and regressors supported labels with shape + `[batch_size]`. +* If you pass custom metrics from the `evaluate()` method call, use + `tf.contrib.estimator.add_metrics`. +* Replace your `serving_input_fn` with a `serving_input_receiver_fn`. + Note this should be entirely distinct from your training `input_fn`, so if you + previously had one `input_fn` with different "modes", you should now factor + that apart. Where the former returned either a simple `(features, labels)` + tuple or `InputFnOps`, you should now return a `ServingInputReceiver`. + If you were generating your `serving_input_fn` using the + `build_parsing_serving_input_fn` helper, you can simply drop in the + replacement `build_parsing_serving_input_receiver_fn`. + +Some remaining estimators/classes: + +* `DynamicRnnEstimator`: Consider a custom `model_fn`. +* `KMeansClustering`: Use `tf.contrib.factorization.KMeansClustering`. +* `LogisticRegressor`: Not supported. Instead, use `binary_classification_head` + with a custom `model_fn`, or with `DNNEstimator`. +* `StateSavingRnnEstimator`: Consider a custom `model_fn`. +* SVM: Consider a custom `model_fn`. +* `LinearComposableModel` and `DNNComposableModel`: Not supported. + Consider `tf.contrib.estimator.DNNEstimator`, or write a custom model_fn. +* `MetricSpec`: Deprecated. For adding custom metrics to canned Estimators, use + `tf.contrib.estimator.add_metrics`. + +## Estimator +`tf.contrib.learn.Estimator` is migrated to `tf.estimator.Estimator`. + +To migrate, users need to take the following steps: + +* Replace `tf.contrib.learn.Estimator` with `tf.estimator.Estimator`. +* If you pass a `config` argument to `Estimator`, this must be + `tf.estimator.RunConfig`. You may need to edit your code accordingly. +* Edit your `model_fn` to return `tf.estimator.EstimatorSpec`. Refer to + `EstimatorSpec` for documentation of specific fields. +* If your `model_fn` uses the `mode` argument, use `tf.estimator.ModeKeys`. + +Some related classes: +* `Evaluable`, `Trainable`: Not supported, merged into `tf.estimator.Estimator`. +* ExportStrategy: Replaced by `tf.estimator.Exporter`. + +## Head/MultiHead +These classes are now supported under `tf.contrib.estimator`, e.g. +`tf.contrib.estimator.multi_class_head` and `tf.contrib.estimator.multi_head`. + +Some differences: + +* `multi_class_head`: If you use `tf.contrib.learn.multi_class_head` with + `n_classes=2`, switch to `tf.contrib.estimator.binary_classification_head`. +* `loss_only_head`: Not supported. +* `poisson_regression_head`: Not supported (yet). +* `binary_svm_head`: Not supported (yet). +* `no_op_train_fn`: Replace it with `tf.no_op`. + +Some arguments are renamed, please refer to documentation. In addition: + +* `loss_fn`: Supported for `multi_label_head`. If you need it for other heads, + please open an issue. +* `metric_class_ids`: Not supported (yet). +* `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. +* `label_name`: Not needed in `tf.estimator`. If you don’t use `multi_head`, + drop this argument. If you use `multi_head`, refer to + `tf.contrib.estimator.multi_head` documentation. + +## Experiment Class - Distributed Training Tooling + +Switch to `tf.estimator.train_and_evaluate`. Some differences: + +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, + `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` + directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement + for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to export + only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirementds (in + particular, the chief node and evaluator node). + +## Others Classes and Functions + +* `tf.contrib.learn.datasets` is deprecated. We are adding ready to use datasets + to tensorflow/models. Many smaller datasets are available from other sources, + such as scikits.learn. Some Python processing may have to be written, but this + is straightforward to implement using the standard modules. +* `tf.contrib.learn.preprocessing`: Deprecated. The python-only preprocessing + functions are not a good fit for TensorFlow. Please use `tf.data`, and + consider tensorflow/transform for more complex use cases. +* `tf.contrib.learn.models`: Not supported, use canned estimators instead. +* `tf.contrib.learn.monitors`: Implement `SessionRunHook` instead. Hook + implementations are in `tf.train`. +* `tf.contrib.learn.learn_io`: Use the methods in `tf.estimator.inputs`, such as + `tf.estimator.inputs.numpy_input_fn`. Some utility functions have no + equivalent, we encourage the use of `tf.data`. + diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 3698af027e38f1063ad829c26eb179734968f813..79bd73faaf1301a2fc4999b64f88d30542577980 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -13,8 +13,11 @@ # limitations under the License. # ============================================================================== -# TODO(ptucker,ipolosukhin): Improve descriptions. -"""High level API for learning. +"""High level API for learning (DEPRECATED). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. See the @{$python/contrib.learn} guide. diff --git a/tensorflow/contrib/learn/python/__init__.py b/tensorflow/contrib/learn/python/__init__.py index bbebd5ab9792cb937219cf937f08c4d4e6e44a92..df23aeb2c433c2b4392f706730f715246ce01cea 100644 --- a/tensorflow/contrib/learn/python/__init__.py +++ b/tensorflow/contrib/learn/python/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index cdc67c77d5fd1df61016835dc75ba44feb458cf9..76e0e8ac8f19026086959f3b197cfd1a81e65a3e 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py index 2284ec46e971731af74f17678fc0d1d3888419e2..fed1c44d1970bf07c808ace817aa9972d7776d88 100644 --- a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py +++ b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py @@ -12,20 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Some common SessionRunHook classes.""" +"""Some common SessionRunHook classes (deprected). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.util.deprecation import deprecated_alias # pylint: disable=invalid-name -LoggingTensorHook = basic_session_run_hooks.LoggingTensorHook -StopAtStepHook = basic_session_run_hooks.StopAtStepHook -CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook -StepCounterHook = basic_session_run_hooks.StepCounterHook -NanLossDuringTrainingError = basic_session_run_hooks.NanLossDuringTrainingError -NanTensorHook = basic_session_run_hooks.NanTensorHook -SummarySaverHook = basic_session_run_hooks.SummarySaverHook +LoggingTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.LoggingTensorHook', + 'tf.train.LoggingTensorHook', + basic_session_run_hooks.LoggingTensorHook) +StopAtStepHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StopAtStepHook', + 'tf.train.StopAtStepHook', + basic_session_run_hooks.StopAtStepHook) +CheckpointSaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.CheckpointSaverHook', + 'tf.train.CheckpointSaverHook', + basic_session_run_hooks.CheckpointSaverHook) +StepCounterHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StepCounterHook', + 'tf.train.StepCounterHook', + basic_session_run_hooks.StepCounterHook) +NanLossDuringTrainingError = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanLossDuringTrainingError', + 'tf.train.NanLossDuringTrainingError', + basic_session_run_hooks.NanLossDuringTrainingError) +NanTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanTensorHook', + 'tf.train.NanTensorHook', + basic_session_run_hooks.NanTensorHook) +SummarySaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.SummarySaverHook', + 'tf.train.SummarySaverHook', + basic_session_run_hooks.SummarySaverHook) # pylint: enable=invalid-name diff --git a/tensorflow/contrib/learn/python/learn/datasets/BUILD b/tensorflow/contrib/learn/python/learn/datasets/BUILD index 8bf372841d04dc9e1339925474801d5aa3af4ccd..2c7215bba3816ff3762e5b7927f650d1c9cbf617 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/BUILD +++ b/tensorflow/contrib/learn/python/learn/datasets/BUILD @@ -44,18 +44,6 @@ py_binary( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - py_test( name = "base_test", size = "small", diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index 7240b0de149051afa045a8113f9e9b212840c311..3c34712ac859d32f549468345950a93d2ed2aa56 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Dataset utilities and synthetic/reference datasets.""" +"""Dataset utilities and synthetic/reference datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.datasets import mnist from tensorflow.contrib.learn.python.learn.datasets import synthetic from tensorflow.contrib.learn.python.learn.datasets import text_datasets +from tensorflow.python.util.deprecation import deprecated # Export load_iris and load_boston. load_iris = base.load_iris @@ -51,6 +57,7 @@ SYNTHETIC = { } +@deprecated(None, 'Please use tf.data.') def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. @@ -73,8 +80,9 @@ def load_dataset(name, size='small', test_with_fake_data=False): return DATASETS[name]() +@deprecated(None, 'Please use tf.data.') def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): - """Creates binary synthetic datasets + """Creates binary synthetic datasets. Args: name: str, name of the dataset to generate diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index ca720ae5ed26e74da12bd6c5a37231b41442f76f..4676eedb206147d178c6a652aa7c2cb48ef888c0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base utilities for loading datasets.""" + +"""Base utilities for loading datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +35,14 @@ import numpy as np from six.moves import urllib from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated + Dataset = collections.namedtuple('Dataset', ['data', 'target']) Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test']) +@deprecated(None, 'Use tf.data instead.') def load_csv_with_header(filename, target_dtype, features_dtype, @@ -53,6 +62,7 @@ def load_csv_with_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def load_csv_without_header(filename, target_dtype, features_dtype, @@ -70,6 +80,7 @@ def load_csv_without_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def shrink_csv(filename, ratio): """Create a smaller dataset of only 1/ratio of original data.""" filename_small = filename.replace('.', '_small.') @@ -84,6 +95,7 @@ def shrink_csv(filename, ratio): i += 1 +@deprecated(None, 'Use scikits.learn.datasets.') def load_iris(data_path=None): """Load Iris dataset. @@ -100,6 +112,7 @@ def load_iris(data_path=None): data_path, target_dtype=np.int, features_dtype=np.float) +@deprecated(None, 'Use scikits.learn.datasets.') def load_boston(data_path=None): """Load Boston housing dataset. @@ -116,20 +129,58 @@ def load_boston(data_path=None): data_path, target_dtype=np.float, features_dtype=np.float) -def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): +@deprecated(None, 'Use the retry module or similar alternatives.') +def retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): """Simple decorator for wrapping retriable functions. Args: initial_delay: the initial delay. + max_delay: the maximum delay allowed (actual max is + max_delay * (1 + jitter). factor: each subsequent retry, the delay is multiplied by this value. (must be >= 1). jitter: to avoid lockstep, the returned delay is multiplied by a random number between (1-jitter) and (1+jitter). To add a 20% jitter, set jitter = 0.2. Must be < 1. + is_retriable: (optional) a function that takes an Exception as an argument + and returns true if retry should be applied. + + Returns: + A function that wraps another function to automatically retry it. + """ + return _internal_retry( + initial_delay=initial_delay, + max_delay=max_delay, + factor=factor, + jitter=jitter, + is_retriable=is_retriable) + + +def _internal_retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): + """Simple decorator for wrapping retriable functions, for internal use only. + + Args: + initial_delay: the initial delay. max_delay: the maximum delay allowed (actual max is max_delay * (1 + jitter). + factor: each subsequent retry, the delay is multiplied by this value. + (must be >= 1). + jitter: to avoid lockstep, the returned delay is multiplied by a random + number between (1-jitter) and (1+jitter). To add a 20% jitter, set + jitter = 0.2. Must be < 1. is_retriable: (optional) a function that takes an Exception as an argument and returns true if retry should be applied. + + Returns: + A function that wraps another function to automatically retry it. """ if factor < 1: raise ValueError('factor must be >= 1; was %f' % (factor,)) @@ -152,7 +203,7 @@ def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): for delay in delays(): try: return fn(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except) + except Exception as e: # pylint: disable=broad-except if is_retriable is None: continue @@ -176,11 +227,13 @@ def _is_retriable(e): return isinstance(e, IOError) and e.errno in _RETRIABLE_ERRNOS -@retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) +@deprecated(None, 'Please use urllib or similar directly.') +@_internal_retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) def urlretrieve_with_retry(url, filename=None): return urllib.request.urlretrieve(url, filename) +@deprecated(None, 'Please write your own downloading logic.') def maybe_download(filename, work_directory, source_url): """Download the data from source url, unless it's already here. diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 37f9175015a239f763c7721cf36ab8063c0a3e32..abbb44c2f5b701829ce16f64eadd8ebc04c84e2c 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for downloading and reading MNIST data.""" +"""Functions for downloading and reading MNIST data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated # CVDF mirror of http://yann.lecun.com/exdb/mnist/ DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' @@ -37,6 +43,7 @@ def _read32(bytestream): return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_images(f): """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. @@ -65,6 +72,7 @@ def extract_images(f): return data +@deprecated(None, 'Please use tf.one_hot on tensors.') def dense_to_one_hot(labels_dense, num_classes): """Convert class labels from scalars to one-hot vectors.""" num_labels = labels_dense.shape[0] @@ -74,6 +82,7 @@ def dense_to_one_hot(labels_dense, num_classes): return labels_one_hot +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_labels(f, one_hot=False, num_classes=10): """Extract the labels into a 1D uint8 numpy array [index]. @@ -103,7 +112,15 @@ def extract_labels(f, one_hot=False, num_classes=10): class DataSet(object): + """Container class for a dataset (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def __init__(self, images, labels, @@ -210,6 +227,8 @@ class DataSet(object): return self._images[start:end], self._labels[start:end] +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def read_data_sets(train_dir, fake_data=False, one_hot=False, @@ -275,5 +294,7 @@ def read_data_sets(train_dir, return base.Datasets(train=train, validation=validation, test=test) +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def load_mnist(train_dir='MNIST-data'): return read_data_sets(train_dir) diff --git a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py index 6e0ba38941ce4650ede9f7210e284bde2ed8e6a9..a4848fa64a72f031ef35c0c3256e97a7326acd60 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Produce DBpedia datasets of a smaller size.""" +"""Produce DBpedia datasets of a smaller size (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py index 9a843168c27d9cae3f55efe4fe4c688d86c745f3..6a0e3350b3d1052249160a2a997a76de7a5040c3 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Synthetic dataset generators.""" +"""Synthetic dataset generators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,8 +26,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.learn.python.learn.datasets.base import Dataset +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def circles(n_samples=100, noise=None, seed=None, @@ -93,6 +100,7 @@ def circles(n_samples=100, return Dataset(data=X[indices], target=y[indices]) +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def spirals(n_samples=100, noise=None, seed=None, diff --git a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py index 2596a2ecaf1572506504831e8b08fab9b5dbc119..ce9466301728082f8e9d99c90989ba8fe623bcf0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Text datasets.""" +"""Text datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,10 +31,12 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated DBPEDIA_URL = 'https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz' +@deprecated(None, 'See contrib/learn/README.md') def maybe_download_dbpedia(data_dir): """Download if DBpedia data is not present.""" train_path = os.path.join(data_dir, 'dbpedia_csv/train.csv') @@ -41,6 +48,7 @@ def maybe_download_dbpedia(data_dir): tfile.extractall(data_dir) +@deprecated(None, 'See contrib/learn/README.md') def load_dbpedia(size='small', test_with_fake_data=False): """Get DBpedia datasets from CSV files.""" if not test_with_fake_data: diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 4981750c94c7ac31e23b7a3f71ca30e3c9573a20..3e64595f312bcc2a2e8dcba589fb993a249b684b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""An estimator is a rule for calculating an estimate of a given quantity. +"""An estimator is a rule for calculating an estimate of a given quantity (deprecated). + +These classes are deprecated and replaced with `tf.estimator`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. # Estimators diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 15277415a1ce83dc1d4a334e60fe1933ba244df0..1f0e4663d060a3850e2002b27f809fde1db47e48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""sklearn cross-support.""" +"""sklearn cross-support (deprecated).""" from __future__ import absolute_import from __future__ import division @@ -132,6 +132,8 @@ class _TransformerMixin(): class NotFittedError(ValueError, AttributeError): """Exception class to raise if estimator is used before fitting. + USE OF THIS EXCEPTION IS DEPRECATED. + This class inherits from both ValueError and AttributeError to help with exception handling and backward compatibility. diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py index a02c726c74946d93b8e1726473db746220b00795..1fa58271e2b886cd143683a759145fd750791473 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow composable models used as building blocks for estimators.""" +"""TensorFlow composable models used as building blocks for estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,6 +39,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated class _ComposableModel(object): @@ -46,6 +52,7 @@ class _ComposableModel(object): _ComposableModel and its subclasses are not part of the public tf.learn API. """ + @deprecated(None, "Please use model_fns in tf.estimator.") def __init__(self, num_label_columns, optimizer, @@ -141,6 +148,10 @@ class _ComposableModel(object): class LinearComposableModel(_ComposableModel): """A _ComposableModel that implements linear regression. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ @@ -252,6 +263,10 @@ class LinearComposableModel(_ComposableModel): class DNNComposableModel(_ComposableModel): """A _ComposableModel that implements a DNN. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/constants.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py index fc69e810244a182b864be856e6720f8584f7aa65..d2548946bc77dea7c452d61c7e2b6e12c3d6239a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/constants.py +++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py @@ -13,9 +13,11 @@ # limitations under the License. # ============================================================================== -"""Constants regarding Estimators. +"""Constants regarding Estimators (deprecated). -This file is obsoleted in the move of Estimator to core. +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ from __future__ import absolute_import from __future__ import division @@ -25,6 +27,8 @@ from __future__ import print_function class ProblemType(object): """Enum-like values for the type of problem that the model solves. + THIS CLASS IS DEPRECATED. + These values are used when exporting the model to produce the appropriate signature function for serving. diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug.py b/tensorflow/contrib/learn/python/learn/estimators/debug.py index 9d5f6c2bf969d7c85d251bf1b06a0307a41b2297..24b067b7e38b12df3d1d0c49f626344217218571 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Debug estimators. +"""Debug estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Debug estimators are bias-only estimators that can be used for debugging and as simple baselines. @@ -118,6 +122,10 @@ def debug_model_fn(features, labels, mode, params, config=None): class DebugClassifier(estimator.Estimator): """A classifier for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -237,6 +245,10 @@ class DebugClassifier(estimator.Estimator): class DebugRegressor(estimator.Estimator): """A regressor for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index c17b41c0f767e19d9c3635a8f60347a49b297cfb..eabebb7e881558471c343c0573cc9a8f4a425312 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Deep Neural Network estimators.""" +"""Deep Neural Network estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -212,6 +217,10 @@ def _dnn_model_fn(features, labels, mode, params, config=None): class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -521,6 +530,10 @@ class DNNClassifier(estimator.Estimator): class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -796,6 +809,10 @@ class DNNRegressor(estimator.Estimator): class DNNEstimator(estimator.Estimator): """A Estimator for TensorFlow DNN models with user specified _Head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 726612235050def6e7addb503cc6646a25de0e42..3d85533d92d17095bae9a69f229171e1bf61ba10 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow estimators for Linear and DNN joined training models.""" +"""TensorFlow estimators for Linear and DNN joined training models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -372,6 +377,10 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): class DNNLinearCombinedEstimator(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -490,6 +499,10 @@ class DNNLinearCombinedEstimator(estimator.Estimator): class DNNLinearCombinedClassifier(estimator.Estimator): """A classifier for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -832,6 +845,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): class DNNLinearCombinedRegressor(estimator.Estimator): """A regressor for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 69440e823ef1ed2d739f28bc14587891f2de80bb..a703dc66e922d48ceb64edc2a979061b8e45db49 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for Dynamic RNNs.""" +"""Estimator for Dynamic RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -540,6 +545,12 @@ def _get_dynamic_rnn_model_fn( class DynamicRnnEstimator(estimator.Estimator): + """Dynamically unrolled RNN (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 4b63e08ab3372849309ee5d28d754de82e9632f4..7a026a15e4aeea0dde4ed9f7de053a757a0abb58 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base Estimator class.""" +"""Base Estimator class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -138,6 +143,7 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): return df.input_builder, df.get_feed_dict_fn() +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input_fn(input_fn): """Creates `FeatureColumn` objects for inputs defined by `input_fn`. @@ -158,6 +164,7 @@ def infer_real_valued_columns_from_input_fn(input_fn): return layers.infer_real_valued_columns(features) +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input(x): """Creates `FeatureColumn` objects for inputs defined by input `x`. @@ -389,6 +396,10 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable): """Abstract BaseEstimator class to train and evaluate TensorFlow models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Users should not instantiate or subclass this class. Instead, use an `Estimator`. """ @@ -399,6 +410,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): Remove this once launcher takes over config functionality _Config = run_config.RunConfig # pylint: disable=invalid-name + @deprecated(None, 'Please replace uses of any Estimator from tf.contrib.learn' + ' with an Estimator from tf.estimator.*') def __init__(self, model_dir=None, config=None): """Initializes a BaseEstimator instance. @@ -457,6 +470,20 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): make RunConfig immutable, and then return it without a copy. return copy.deepcopy(self._config) + @property + def model_fn(self): + """Returns the model_fn which is bound to self.params. + + Returns: + The model_fn with the following signature: + `def model_fn(features, labels, mode, metrics)` + """ + + def public_model_fn(features, labels, mode, config): + return self._call_model_fn(features, labels, mode, config=config) + + return public_model_fn + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), ('y', None), ('batch_size', None)) def fit(self, @@ -890,8 +917,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, if feed_fn: hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn)) if steps == 0: - logging.warning('evaluation steps are 0. If `input_fn` does not raise' - 'OutOfRangeError`, the evaluation will never stop.' + logging.warning('evaluation steps are 0. If `input_fn` does not raise ' + '`OutOfRangeError`, the evaluation will never stop. ' 'Use steps=None if intended.') if steps: hooks.append( @@ -1074,6 +1101,10 @@ def _identity_feature_engineering_fn(features, labels): class Estimator(BaseEstimator): """Estimator class is the basic TensorFlow model trainer/evaluator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ def __init__(self, @@ -1162,7 +1193,7 @@ class Estimator(BaseEstimator): self._feature_engineering_fn = ( feature_engineering_fn or _identity_feature_engineering_fn) - def _call_model_fn(self, features, labels, mode, metrics=None): + def _call_model_fn(self, features, labels, mode, metrics=None, config=None): """Calls model function with support of 2, 3 or 4 arguments. Args: @@ -1170,6 +1201,7 @@ class Estimator(BaseEstimator): labels: labels dict. mode: ModeKeys metrics: Dict of metrics. + config: RunConfig. Returns: A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a @@ -1186,7 +1218,10 @@ class Estimator(BaseEstimator): if 'params' in model_fn_args: kwargs['params'] = self.params if 'config' in model_fn_args: - kwargs['config'] = self.config + if config: + kwargs['config'] = config + else: + kwargs['config'] = self.config if 'model_dir' in model_fn_args: kwargs['model_dir'] = self.model_dir model_fn_results = self._model_fn(features, labels, **kwargs) @@ -1458,8 +1493,14 @@ class Estimator(BaseEstimator): # For time of deprecation x,y from Estimator allow direct access. # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): - """Scikit learn wrapper for TensorFlow Learn Estimator.""" + """Scikit learn wrapper for TensorFlow Learn Estimator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please switch to the Estimator interface.') def __init__(self, estimator): self._estimator = estimator diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py index fd47710e3015de9ae6a453f98978b0ef8f88968c..e4c31396baf8271c49395a2b87b454dbc77177e2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utils for Estimator.""" +"""Utils for Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 9b124b2c19f16bbc9b2afeadb82a32006e1a0ae9..2b4b6eff39f4fc8a20a149edfc07d2f4f27a9bae 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Abstractions for the head(s) of a model. +"""Abstractions for the head(s) of a model (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -47,11 +52,16 @@ from tensorflow.python.summary import summary from tensorflow.python.training import training from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated class Head(object): """Interface for the head/top of a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Given logits (or output of a hidden layer), a Head knows how to compute predictions, loss, default metric and export signature. It is meant to, @@ -177,6 +187,7 @@ class Head(object): raise NotImplementedError("Calling an abstract method.") +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -216,6 +227,7 @@ def regression_head(label_name=None, link_fn=(link_fn if link_fn is not None else array_ops.identity)) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def poisson_regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -254,6 +266,7 @@ def poisson_regression_head(label_name=None, # TODO(zakaria): Consider adding a _RegressionHead for logistic_regression +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_class_head(n_classes, label_name=None, weight_column_name=None, @@ -335,6 +348,7 @@ def multi_class_head(n_classes, label_keys=label_keys) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def binary_svm_head( label_name=None, weight_column_name=None, @@ -370,6 +384,7 @@ def binary_svm_head( thresholds=thresholds) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_label_head(n_classes, label_name=None, weight_column_name=None, @@ -430,6 +445,7 @@ def multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def loss_only_head(loss_fn, head_name=None): """Creates a Head that contains only loss terms. @@ -447,6 +463,7 @@ def loss_only_head(loss_fn, head_name=None): return _LossOnlyHead(loss_fn, head_name=head_name) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. @@ -479,6 +496,7 @@ def multi_head(heads, loss_weights=None): return _MultiHead(heads, loss_merger=_weighted_loss_merger) +@deprecated(None, "Use 'lambda _: tf.no_op()'.") def no_op_train_fn(loss): del loss return control_flow_ops.no_op() diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 6d5da81b4c2087fb9c5307902e452a6220a17cd0..7c2d9bb0767cb979dae9c84b5342d129225677ed 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -362,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase): "auc_precision_recall": 0.166667, "auc_precision_recall/class0": 0, "auc_precision_recall/class1": 0., - "auc_precision_recall/class2": 0.49999, + "auc_precision_recall/class2": 1., "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -748,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. / 2, - "auc_precision_recall": 0.25, + "auc_precision_recall": 0.749999, "labels/actual_label_mean": label_mean, "labels/prediction_mean": .731059, # softmax "loss": expected_loss, diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 8f9d6fc318a357853bdb8e3264f6691b410006b1..66ebcfd1d81904b9afe5be6bd1a648fe325e1e0b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API. +"""Implementation of k-means clustering on top of `Estimator` API (deprecated). This module is deprecated. Please use @{tf.contrib.factorization.KMeansClustering} instead of @@ -153,7 +153,12 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE RANDOM_INIT = clustering_ops.RANDOM_INIT diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py index b28835a809736a099ad2f08d127dc68d7977a3c1..584556992a0db2345e182e92c4a7f7582d3cd8dc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import random_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import flags from tensorflow.python.platform import test diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 37aa8b339622415d082933cdf66d2472a4119b48..70b70af98c51dcb991c19152607272673953ee2a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Linear Estimators.""" +"""Linear Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -238,8 +243,8 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" - with variable_scope.variable_op_scope( - features.values(), parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -305,6 +310,10 @@ class _SdcaUpdateWeightsHook(session_run_hook.SessionRunHook): class LinearClassifier(estimator.Estimator): """Linear classifier model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear model to classify instances into one of multiple possible classes. When number of possible classes is 2, this is binary classification. @@ -625,6 +634,10 @@ class LinearClassifier(estimator.Estimator): class LinearRegressor(estimator.Estimator): """Linear regressor model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear regression model to predict label value given observation of feature values. @@ -860,6 +873,10 @@ class LinearRegressor(estimator.Estimator): class LinearEstimator(estimator.Estimator): """Linear model with user specified head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a generalized linear model to predict label value given observation of feature values. diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py index fb339160d58e09d4ffd50090f2dbbcec08bebe47..3cbcc6e98de1c915c302617e4591c9baa33adeaf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Logistic regression (aka binary classifier) class. +"""Logistic regression (aka binary classifier) class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This defines some useful basic metrics for using logistic regression to classify a binary event (0 vs 1). @@ -75,6 +79,10 @@ def LogisticRegressor( # pylint: disable=invalid-name feature_engineering_fn=None): """Builds a logistic regression Estimator for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This method provides a basic Estimator with some additional metrics for custom binary classification models, including AUC, precision/recall and accuracy. diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py index 99388f116b345bd038f2985606c6203011597ea2..f264248e44d9aa48f26ee32e36746bd4c3145a8d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for metric keys.""" +"""Enum for metric keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function class MetricKey(object): - """Metric key strings.""" + """Metric key strings (deprecated).""" + LOSS = "loss" AUC = "auc" AUC_PR = "auc_precision_recall" diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 44e6c7c52dac524a22e9099e33e2aef82f8fe7ba..dcb161180c99ce71195c820217e8bdaf79d70901 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Classes and methods related to model_fn.""" +"""Classes and methods related to model_fn (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -37,10 +42,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import session_run_hook +from tensorflow.python.util.deprecation import deprecated class ModeKeys(object): - """Standard names for model modes. + """Standard names for model modes (deprecated). + + THIS CLASS IS DEPRECATED. The following standard keys are defined: @@ -65,8 +73,16 @@ class ModelFnOps( 'output_alternatives', 'training_chief_hooks', 'training_hooks', 'scaffold', 'mode' ])): - """Ops returned from a model_fn.""" + """Ops returned from a model_fn. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'When switching to tf.estimator.Estimator, use ' + 'tf.estimator.EstimatorSpec. You can use the `estimator_spec`' + ' method to create an equivalent one.') def __new__(cls, mode, predictions=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py index f8d87b8914307a86eb2f46123a28ff11eb925eda..6fd2fc9d592cef4e44a640e2f27cb28b367d44d5 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for model prediction keys. +"""Enum for model prediction keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This file is obsoleted in the move of Estimator to core. """ @@ -22,6 +26,8 @@ from __future__ import print_function class PredictionKey(object): + """THIS CLASS IS DEPRECATED.""" + CLASSES = "classes" PROBABILITIES = "probabilities" LOGITS = "logits" diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index 2752bc2d90ee0f51b2c40ccc4d24a4eb21cff38f..215022e5d9e5d3cd5d6a96583b325b19a1719568 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common operations for RNN Estimators.""" +"""Common operations for RNN Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index fd90fd1cc6277e7d80287aefdbab6134dac7c0d5..14ee2ba6094760d52180d6de7763ea88b8ee98c8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Run Config.""" +"""Run Config (deprecated, use tf.estimator.RunConfig instead). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +34,12 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as core_run_config from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util.deprecation import deprecated # A list of the property names in RunConfig user allows to change. They will # not affect the execution framework, so when execution framework checks the -# `uid` of the RunConfig, it should be ingored. +# `uid` of the RunConfig, it should be ignored. _DEFAULT_UID_WHITE_LIST = [ 'tf_random_seed', 'save_summary_steps', @@ -47,6 +53,7 @@ _DEFAULT_UID_WHITE_LIST = [ class Environment(object): + """DEPRECATED CLASS.""" # For running general distributed training. CLOUD = 'cloud' # For running Google-internal distributed training. @@ -56,6 +63,7 @@ class Environment(object): class TaskType(object): + """DEPRECATED CLASS.""" MASTER = 'master' PS = 'ps' WORKER = 'worker' @@ -64,6 +72,8 @@ class TaskType(object): class ClusterConfig(object): """This class specifies the configurations for a distributed run. + THIS CLASS IS DEPRECATED. Use tf.estimator.RunConfig instead. + If you're using an `Estimator`, you should probably use the subclass RunConfig instead. """ @@ -211,10 +221,13 @@ class ClusterConfig(object): class RunConfig(ClusterConfig, core_run_config.RunConfig): """This class specifies the configurations for an `Estimator` run. - This class is the implementation of @{tf.estimator.RunConfig} interface. + This class is a deprecated implementation of @{tf.estimator.RunConfig} + interface. """ _USE_DEFAULT = 0 + @deprecated(None, 'When switching to tf.estimator.Estimator, use' + ' tf.estimator.RunConfig instead.') def __init__(self, master=None, num_cores=0, @@ -277,8 +290,16 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): Note - using this argument, it is easy to provide settings which break otherwise perfectly good models. Use with care. """ - super(RunConfig, self).__init__( - master=master, evaluation_master=evaluation_master) + # Neither parent class calls super().__init__(), so here we have to + # manually call their __init__() methods. + ClusterConfig.__init__( + self, master=master, evaluation_master=evaluation_master) + # For too long this code didn't call: + # core_run_config.RunConfig.__init__(self) + # so instead of breaking compatibility with that assumption, we + # just manually initialize this field: + self._train_distribute = None + self._device_fn = None gpu_options = config_pb2.GPUOptions( per_process_gpu_memory_fraction=gpu_memory_fraction) diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index 0cea35e219a4457417a161a3ac4ac4292fd24f53..de78c72c3ae3ef14f5f7c46b1d47f82e8266c7c6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for State Saving RNNs.""" +"""Estimator for State Saving RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -528,6 +533,12 @@ def _get_rnn_model_fn(cell_type, class StateSavingRnnEstimator(estimator.Estimator): + """RNN with static unrolling and state saving (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index 72920d73c0c92886e54f533ad7fe170fe27d9870..3459997baba16fc0d4045e50819ecdd0e7121657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support Vector Machine (SVM) Estimator.""" +"""Support Vector Machine (SVM) Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -36,6 +41,10 @@ def _as_iterable(preds, output): class SVM(estimator.Estimator): """Support Vector Machine (SVM) model for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Currently, only linear SVMs are supported. For the underlying optimization problem, the `SDCAOptimizer` is used. For performance and convergence tuning, the num_loss_partitions parameter passed to `SDCAOptimizer` (see `__init__()` diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py index a120bc6cc3975a3d4559d018c8aa74ff42a16d2d..71b5658dd174d2b47e33860844359f68e6768026 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py +++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorSignature class and utilities.""" +"""TensorSignature class and utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -33,6 +38,10 @@ class TensorSignature(collections.namedtuple( "TensorSignature", ["dtype", "shape", "is_sparse"])): """Signature of the `Tensor` object. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Useful to check compatibility of tensors. Example: diff --git a/tensorflow/contrib/learn/python/learn/estimators/test_data.py b/tensorflow/contrib/learn/python/learn/estimators/test_data.py index ed201bfc58f273e6587850032386c2686aea4148..e4b057b4f5a9e081c2d891bd9828ffc315e51e91 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/test_data.py +++ b/tensorflow/contrib/learn/python/learn/estimators/test_data.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test data utilities.""" +"""Test data utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 8f6cd39864b437f163dd7c1140dc88755ce98529..10881ca885599bc81386e15f814a2687d907f63b 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Evaluable` interface.""" +"""`Evaluable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,10 @@ import abc class Evaluable(object): """Interface for objects that are evaluatable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index bec976afd2719138117976381669ca3292360480..3744abd860e7f460133873eb534fd75887182f78 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experiment class collecting information needed for a single training run.""" +"""Experiment class collecting information for a single training run (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -25,7 +30,6 @@ import os import time from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import export_strategy @@ -118,6 +122,10 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): class Experiment(object): """Experiment is a class containing all information needed to train a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + After an experiment is created (by passing an Estimator and inputs for training and evaluation), an Experiment instance knows how to invoke training and eval loops in a sensible fashion for distributed training. @@ -125,16 +133,8 @@ class Experiment(object): # TODO(ispir): remove delay_workers_by_global_step and make global step based # waiting as only behavior. - @deprecated_args( - "2016-10-23", - "local_eval_frequency is deprecated as local_run will be renamed to " - "train_and_evaluate. Use min_eval_frequency and call train_and_evaluate " - "instead. Note, however, that the default for min_eval_frequency is 1, " - "meaning models will be evaluated every time a new checkpoint is " - "available. In contrast, the default for local_eval_frequency is None, " - "resulting in evaluation occurring only after training has completed. " - "min_eval_frequency is ignored when calling the deprecated local_run.", - "local_eval_frequency") + @deprecated(None, "Please switch to tf.estimator.train_and_evaluate. You will" + " also have to convert to a tf.estimator.Estimator.") def __init__(self, estimator, train_input_fn, @@ -152,7 +152,8 @@ class Experiment(object): export_strategies=None, train_steps_per_iteration=None, checkpoint_and_export=False, - saving_listeners=None): + saving_listeners=None, + check_interval_secs=5): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -190,8 +191,9 @@ class Experiment(object): number of steps between evaluations. Of course, evaluation does not occur if no new snapshot is available, hence, this is the minimum. If 0, the evaluation will only happen after training. - If None, defaults to 1, unless model_dir is on GCS, in which case the - default is 1000. + If None, defaults to 1. To avoid checking for new checkpoints too + frequent, the interval is further limited to be at least + check_interval_secs between checks. delay_workers_by_global_step: if `True` delays training workers based on global step instead of time. export_strategies: Iterable of `ExportStrategy`s, or a single one, or @@ -215,7 +217,10 @@ class Experiment(object): saving_listeners: list of `CheckpointSaverListener` objects. Used by tf.estimator.Estimator for callbacks that run immediately before or after checkpoint savings. - + check_interval_secs: + Minimum time between subsequent checks for a new checkpoint. This + mostly applies if both min_eval_frequency and the time spent per + training step is low. Raises: ValueError: if `estimator` does not implement Estimator interface, or if export_strategies has the wrong type. @@ -261,13 +266,9 @@ class Experiment(object): self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._checkpoint_and_export = checkpoint_and_export self._saving_listeners = saving_listeners - # Using 1 on a non-cached file system requires a lot of overhead to - # read the checkpoint state file. This is particular bad on GCS, so - # we use a different default. This is a temporary band-aid, to be - # fixed holistically later (b/36498507). - default_min_eval_frequency = 1000 if _is_gcs(estimator.model_dir) else 1 self._min_eval_frequency = min_eval_frequency if ( - min_eval_frequency is not None) else default_min_eval_frequency + min_eval_frequency is not None) else 1 + self._check_interval_secs = check_interval_secs self._delay_workers_by_global_step = delay_workers_by_global_step self._train_monitors = train_monitors[:] if train_monitors else [] self._eval_hooks = eval_hooks[:] if eval_hooks else [] @@ -357,7 +358,7 @@ class Experiment(object): self._start_server() elif config.cluster_spec and config.master: raise ValueError( - "For distributed runtime, Experiment class only works with" + "For distributed runtime, Experiment class only works with " "tf.contrib.learn.RunConfig for now, but provided {}".format( type(config))) @@ -646,12 +647,19 @@ class Experiment(object): self._train_monitors += [saver_hook] else: if self._min_eval_frequency: + # Using low min_eval_frequency (default is 1) on a non-cached file + # system requires a lot of overhead to read the checkpoint state file. + # This is particular bad on GCS and CNS. See also b/36498507 for + # context. `check_interval_secs = 5` avoids polling a remote + # fileystem too often. + self._train_monitors += [ monitors.ValidationMonitor( input_fn=self._eval_input_fn, eval_steps=self._eval_steps, metrics=self._eval_metrics, every_n_steps=self._min_eval_frequency, + check_interval_secs=self._check_interval_secs, name=eval_dir_suffix, hooks=self._eval_hooks) ] @@ -928,7 +936,3 @@ def _new_attr_context(obj, attr): yield finally: setattr(obj, attr, saved) - - -def _is_gcs(model_dir): - return model_dir and model_dir.startswith("gs://") diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 545d7d8924c0c10544e6113e2968b7ae3d2090fc..d10927a0cdd5c67c8d2a8e569153235ee175ec4d 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -674,37 +674,11 @@ class ExperimentTest(test.TestCase): def test_min_eval_frequency_defaults(self): def dummy_model_fn(features, labels): # pylint: disable=unused-argument pass - - # The default value when model_dir is on GCS is 1000 - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, train_input_fn=None, eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 1000) - - # The default value when model_dir is not on GCS is 1 estimator = core_estimator.Estimator(dummy_model_fn, '/tmp/dummy') ex = experiment.Experiment( estimator, train_input_fn=None, eval_input_fn=None) self.assertEquals(ex._min_eval_frequency, 1) - # Make sure default not used when explicitly set - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, - min_eval_frequency=123, - train_input_fn=None, - eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 123) - - # Make sure default not used when explicitly set as 0 - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, - min_eval_frequency=0, - train_input_fn=None, - eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 0) - def test_continuous_train_and_eval(self): for est in self._estimators_for_tests(eval_dict={'global_step': 100}): if isinstance(est, core_estimator.Estimator): diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index 55a8b824312b89e0ac66513242191f4201ac212a..075cab536ecb5279e7e6f23abb0b70c75043a7ec 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""ExportStrategy class represents different flavors of model export.""" +"""ExportStrategy class represents different flavors of model export (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function import collections from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated __all__ = ['ExportStrategy'] @@ -30,6 +36,10 @@ class ExportStrategy( ['name', 'export_fn', 'strip_default_attrs'])): """A class representing a type of model export. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Typically constructed by a utility function specific to the exporter, such as `saved_model_export_utils.make_export_strategy()`. @@ -56,6 +66,8 @@ class ExportStrategy( forward compatibility of the resulting `SavedModel`. """ + @deprecated(None, 'Please switch to tf.estimator.train_and_evaluate, and use ' + 'tf.estimator.Exporter.') def __new__(cls, name, export_fn, strip_default_attrs=None): return super(ExportStrategy, cls).__new__( cls, name, export_fn, strip_default_attrs) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 98365c05f663e5d2a06703457fc5663d7135f7d9..a997fab723a16dddf150aa9397863605e4e77933 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level operations on graphs.""" +"""High level operations on graphs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -68,6 +73,7 @@ def clear_summary_writers(): return summary_io.SummaryWriterCache.clear() +@deprecated(None, 'Use `SummaryWriterCache.get` directly.') def get_summary_writer(logdir): """Returns single SummaryWriter per logdir in current run. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py index 06c3782a471537cf3879450e6bd20899a35d96ac..8b133a4440d8cbc19abca64f972791fc16ade6f8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tools to allow different io formats.""" +"""Tools to allow different io formats (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py index 7d666391cea3c0a52a2cb7e324c00d5f480710d5..e0a1948d95a727675dac8ff3ce9f55c35d5f8d8d 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Methods to allow dask.DataFrame.""" +"""Methods to allow dask.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.util.deprecation import deprecated + try: # pylint: disable=g-import-not-at-top import dask.dataframe as dd @@ -60,6 +67,7 @@ def _construct_dask_df_with_divisions(df): return dd.Series(merge(dsk, df.dask), name, df.name, divisions) +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_data(data): """Extract data from dask.Series or dask.DataFrame for predictors. @@ -81,6 +89,7 @@ def extract_dask_data(data): return data +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_labels(labels): """Extract data from dask.Series or dask.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 96be8b1bc402479d5611965f27abb197363cb939..c45b1d186471125776d6536112aebb66bb5ad558 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementations of different data feeders to provide data for TF trainer.""" +"""Implementations of different data feeders to provide data for TF trainer (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues. @@ -31,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels @@ -101,6 +107,7 @@ def _is_iterable(x): return hasattr(x, 'next') or hasattr(x, '__next__') +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_train_data_feeder(x, y, n_classes, @@ -188,6 +195,7 @@ def _batch_data(x, batch_size=None): yield np.matrix(chunk) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_predict_data_feeder(x, batch_size=None): """Returns an iterable for feeding into predict step. @@ -219,6 +227,7 @@ def setup_predict_data_feeder(x, batch_size=None): return [x] +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_processor_data_feeder(x): """Sets up processor iterable. @@ -233,6 +242,7 @@ def setup_processor_data_feeder(x): return x +@deprecated(None, 'Please convert numpy dtypes explicitly.') def check_array(array, dtype): """Checks array on dtype and converts it if different. @@ -275,8 +285,14 @@ def _check_dtype(dtype): class DataFeeder(object): - """Data feeder is an example class to sample data for TF trainer.""" + """Data feeder is an example class to sample data for TF trainer. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, x, y, @@ -563,6 +579,10 @@ class DataFeeder(object): class StreamingDataFeeder(DataFeeder): """Data feeder for TF trainer that reads data from iterator. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Streaming data feeder allows to read data as it comes it from disk or somewhere else. It's custom to have this iterators rotate infinetly over the dataset, to allow control of how much to learn on the trainer side. @@ -771,11 +791,16 @@ class StreamingDataFeeder(DataFeeder): class DaskDataFeeder(object): """Data feeder for that reads data from dask.Series and dask.DataFrame. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Numpy arrays can be serialized to disk and it's possible to do random seeks into them. DaskDataFeeder will remove requirement to have full dataset in the memory and still do random seeks for sampling of batches. """ + @deprecated(None, 'Please feed input to tf.data to support dask.') def __init__(self, x, y, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py index 82848be7df653dd60219317d28f233767746f544..1f439965daf956665bbedc919281df0ee07b5d62 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os.path import numpy as np import six from six.moves import xrange # pylint: disable=redefined-builtin @@ -26,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.learn.python.learn.learn_io import * from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io from tensorflow.python.platform import test # pylint: enable=wildcard-import @@ -35,6 +37,13 @@ class DataFeederTest(test.TestCase): # pylint: disable=undefined-variable """Tests for `DataFeeder`.""" + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), 'base_dir') + file_io.create_dir(self._base_dir) + + def tearDown(self): + file_io.delete_recursively(self._base_dir) + def _wrap_dict(self, data, prepend=''): return {prepend + '1': data, prepend + '2': data} @@ -45,14 +54,14 @@ class DataFeederTest(test.TestCase): def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data): feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) if isinstance(input_data, dict): - for k, v in list(feeder.input_dtype.items()): + for v in list(feeder.input_dtype.values()): self.assertEqual(expected_np_dtype, v) else: self.assertEqual(expected_np_dtype, feeder.input_dtype) with ops.Graph().as_default() as g, self.test_session(g): inp, _ = feeder.input_builder() if isinstance(inp, dict): - for k, v in list(inp.items()): + for v in list(inp.values()): self.assertEqual(expected_tf_dtype, v.dtype) else: self.assertEqual(expected_tf_dtype, inp.dtype) @@ -301,7 +310,10 @@ class DataFeederTest(test.TestCase): [0.60000002, 0.2]]) self.assertAllClose(feed_dict[out.name], [[0., 0., 1.], [0., 1., 0.]]) - def test_hdf5_data_feeder(self): + # TODO(rohanj): Fix this test by fixing data_feeder. Currently, h5py doesn't + # support permutation based indexing lookups (More documentation at + # http://docs.h5py.org/en/latest/high/dataset.html#fancy-indexing) + def DISABLED_test_hdf5_data_feeder(self): def func(df): inp, out = df.input_builder() @@ -314,11 +326,12 @@ class DataFeederTest(test.TestCase): import h5py # pylint: disable=g-import-not-at-top x = np.matrix([[1, 2], [3, 4]]) y = np.array([1, 2]) - h5f = h5py.File('test_hdf5.h5', 'w') + file_path = os.path.join(self._base_dir, 'test_hdf5.h5') + h5f = h5py.File(file_path, 'w') h5f.create_dataset('x', data=x) h5f.create_dataset('y', data=y) h5f.close() - h5f = h5py.File('test_hdf5.h5', 'r') + h5f = h5py.File(file_path, 'r') x = h5f['x'] y = h5f['y'] func(data_feeder.DataFeeder(x, y, n_classes=0, batch_size=3)) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py index 884faf8335e2a3ca1d27d2d93b4c817131648774..f8aaa0c9e3e5b589a6ad47678dba3dc38de7c471 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow generator of dict with numpy arrays.""" +"""Methods to allow generator of dict with numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,8 +28,10 @@ from types import FunctionType from types import GeneratorType from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.data.') def generator_input_fn(x, target_key=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 3a46c239688017f9204d2c6182a6f81cd325a417..9e816f54b6cf8dee84c6d62406ab3db700054d06 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to read data in the graph.""" +"""Methods to read data in the graph (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,11 +39,13 @@ from tensorflow.python.platform import gfile from tensorflow.python.summary import summary from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner +from tensorflow.python.util.deprecation import deprecated # Default name for key in the feature dict. KEY_FEATURE_NAME = '__key__' +@deprecated(None, 'Use tf.data.') def read_batch_examples(file_pattern, batch_size, reader, @@ -106,6 +113,7 @@ def read_batch_examples(file_pattern, return examples +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples(file_pattern, batch_size, reader, @@ -175,6 +183,7 @@ def read_keyed_batch_examples(file_pattern, seed=seed) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples_shared_queue(file_pattern, batch_size, reader, @@ -452,6 +461,7 @@ def _read_keyed_batch_examples_helper(file_pattern, return queued_examples_with_keys +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features(file_pattern, batch_size, features, @@ -540,6 +550,7 @@ def read_keyed_batch_features(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features_shared_queue(file_pattern, batch_size, features, @@ -620,6 +631,7 @@ def read_keyed_batch_features_shared_queue(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def queue_parsed_features(parsed_features, keys=None, feature_queue_capacity=100, @@ -742,6 +754,7 @@ def queue_parsed_features(parsed_features, return dequeued_keys, dequeued_parsed_features +@deprecated(None, 'Use tf.data.') def read_batch_features(file_pattern, batch_size, features, @@ -821,6 +834,7 @@ def read_batch_features(file_pattern, return features +@deprecated(None, 'Use tf.data.') def read_batch_record_features(file_pattern, batch_size, features, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 692438807fbd7febb156d4db73b5d3deba6c987d..29552d24f1eaa0d85a99c8b09f69d007e7e4fe9f 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow dict of numpy arrays.""" +"""Methods to allow dict of numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_numpy_input_fn +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Use tf.estimator.inputs.numpy_input_fn.') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index ede7558eafa9237dc63aa95a62e599c5e9755822..b4ef055f5ae484ec704ad42efcf2c00c4a7a4f56 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -13,13 +13,19 @@ # limitations under the License. # ============================================================================== -"""Methods to allow pandas.DataFrame.""" +"""Methods to allow pandas.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn as core_pandas_input_fn +from tensorflow.python.util.deprecation import deprecated try: # pylint: disable=g-import-not-at-top @@ -47,6 +53,7 @@ PANDAS_DTYPES = { } +@deprecated(None, 'Please use tf.estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, @@ -66,6 +73,7 @@ def pandas_input_fn(x, target_column=target_column) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_data(data): """Extract data from pandas.DataFrame for predictors. @@ -96,6 +104,7 @@ def extract_pandas_data(data): 'float, or bool. Found: ' + ', '.join(error_report)) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_matrix(data): """Extracts numpy matrix from pandas DataFrame. @@ -111,6 +120,7 @@ def extract_pandas_matrix(data): return data.as_matrix() +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_labels(labels): """Extract data from pandas.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 2af723a0d64822e81fa0fbeb106ab812de6ab4e8..d719a3e488b9905ef7903e21d90dbaae0449735c 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Runs an Experiment.""" +"""Runs an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config as run_c from tensorflow.contrib.learn.python.learn.experiment import Experiment from tensorflow.contrib.training.python.training import hparam as hparam_lib from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # TODO(xiejw): Refactor the learn_runner to make code reusable. @@ -99,6 +105,7 @@ def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False): return wrapped_experiment_fn +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def run(experiment_fn, output_dir=None, schedule=None, run_config=None, hparams=None): """Make and run an experiment. @@ -218,6 +225,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, return _execute_schedule(experiment, schedule) +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def tune(experiment_fn, tuner): """Tune an experiment with hyper-parameters. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py index 7d9b1c7716f0ab1f2274ca53406175240b613027..ba2d067787c1dfd4e4820ecc916f1053e9f3cf60 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities to run and tune an Experiment. +"""Utilities to run and tune an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@run @@tune diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index 6440bc204b8e339ff51311dcc87b36f556b94092..97220365d5dddb82b602369f06bea021a86d584f 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The metric spec class to flexibly connect models and metrics.""" +"""The metric spec class to flexibly connect models and metrics (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ import six from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated def _assert_named_args(sentinel): @@ -223,6 +229,10 @@ def _adapt_metric_fn( class MetricSpec(object): """MetricSpec connects a model to metric functions. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + The MetricSpec class contains all information necessary to connect the output of a `model_fn` to the metrics (usually, streaming metrics) that are used in evaluation. @@ -284,6 +294,7 @@ class MetricSpec(object): """ + @deprecated(None, 'Use tf.estimator.EstimatorSpec.eval_metric_ops.') def __init__(self, metric_fn, prediction_key=None, diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py index 4283240d018c949bb35aeb12032d2ee8b75884a5..bd4bbf9f8c9ad7e8a0fc06d8c0dc24672536c158 100644 --- a/tensorflow/contrib/learn/python/learn/models.py +++ b/tensorflow/contrib/learn/python/learn/models.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Various high level TF models.""" +"""Various high level TF models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -28,8 +33,10 @@ from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using a tf.estimator.LinearRegressor') def linear_regression_zero_init(x, y): """Linear regression subgraph with zero-value initial weights and bias. @@ -43,6 +50,7 @@ def linear_regression_zero_init(x, y): return linear_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.LinearClassifier') def logistic_regression_zero_init(x, y): """Logistic regression subgraph with zero-value initial weights and bias. @@ -56,6 +64,7 @@ def logistic_regression_zero_init(x, y): return logistic_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.') def linear_regression(x, y, init_mean=None, init_stddev=1.0): """Creates linear regression TensorFlow subgraph. @@ -107,6 +116,7 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0): return losses_ops.mean_squared_error_regressor(x, y, weights, bias) +@deprecated(None, 'Consider using a class from tf.estimator.') def logistic_regression(x, y, class_weight=None, @@ -203,6 +213,7 @@ def _reverse_seq(input_seq, lengths): return result +@deprecated(None, 'Please consider `tf.nn.bidirectional_dynamic_rnn`.') def bidirectional_rnn(cell_fw, cell_bw, inputs, @@ -283,6 +294,7 @@ def bidirectional_rnn(cell_fw, # End of TensorFlow 0.7 +@deprecated(None, 'Please consider tensorflow/tensor2tensor.') def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn, bidirectional, target_predictor_fn, sequence_length, initial_state, attn_length, attn_size, attn_vec_size): diff --git a/tensorflow/contrib/learn/python/learn/monitored_session.py b/tensorflow/contrib/learn/python/learn/monitored_session.py index 22602e9f69d972505d83a66a6f9183b5e4d15c44..ac0433f1775feeed2ec3cf49291da01500bef01b 100644 --- a/tensorflow/contrib/learn/python/learn/monitored_session.py +++ b/tensorflow/contrib/learn/python/learn/monitored_session.py @@ -13,7 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A wrapper of Session API which runs hooks.""" +"""A wrapper of Session API which runs hooks (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 51381a7427c919592b8e818c4b46dba974992610..77f7c73d5412d40b338eaff4cf04d99fd0892723 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monitors instrument the training process. +"""Monitors instrument the training process (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@get_default_monitors @@BaseMonitor @@ -59,6 +63,10 @@ from tensorflow.python.util import tf_inspect class BaseMonitor(object): """Base class for Monitors. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Defines basic interfaces of Monitors. Monitors can either be run on all workers or, more commonly, restricted to run exclusively on the elected chief worker. @@ -229,6 +237,10 @@ def _extract_output(outputs, request): class EveryN(BaseMonitor): """Base class for monitors that execute callbacks every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This class adds three new callbacks: - every_n_step_begin - every_n_step_end @@ -418,6 +430,10 @@ class StopAtStep(BaseMonitor): class PrintTensor(EveryN): """Prints given tensors every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This is an `EveryN` monitor and has consistent semantic for `every_n` and `first_n`. @@ -455,9 +471,12 @@ class PrintTensor(EveryN): class LoggingTrainable(EveryN): """Writes trainable variable values into log every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Write the tensors in trainable variables `every_n` steps, starting with the `first_n`th step. - """ def __init__(self, scope=None, every_n=100, first_n=1): @@ -493,7 +512,12 @@ class LoggingTrainable(EveryN): class SummarySaver(EveryN): - """Saves summaries every N steps.""" + """Saves summaries every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, summary_op, @@ -554,6 +578,10 @@ class SummarySaver(EveryN): class ValidationMonitor(EveryN): """Runs evaluation of a given estimator, at most every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note that the evaluation is done based on the saved checkpoint, which will usually be older than the current step. @@ -573,7 +601,8 @@ class ValidationMonitor(EveryN): early_stopping_rounds=None, early_stopping_metric="loss", early_stopping_metric_minimize=True, - name=None): + name=None, + check_interval_secs=5): """Initializes a ValidationMonitor. Args: @@ -600,6 +629,9 @@ class ValidationMonitor(EveryN): loss metrics like mean squared error, and False for performance metrics like accuracy. name: See `BaseEstimator.evaluate`. + check_interval_secs: Only check for new checkpoint if at least + `check_interval_secs` have passed. Ignore if None. Default is 5 secs. + Raises: ValueError: If both x and input_fn are provided. @@ -626,6 +658,8 @@ class ValidationMonitor(EveryN): self._early_stopped = False self._latest_path = None self._latest_path_step = None + self._last_checkpoint_check_time = None + self._check_interval_secs = check_interval_secs @property def early_stopped(self): @@ -690,6 +724,16 @@ class ValidationMonitor(EveryN): # that's what is being evaluated. if self._estimator is None: raise ValueError("Missing call to set_estimator.") + current_time = time.time() + if (self._check_interval_secs is not None and + self._last_checkpoint_check_time is not None and + current_time - self._last_checkpoint_check_time <= + self._check_interval_secs): + logging.debug( + "Skipping evaluation since less than %d seconds have passed since " + "last check for a new checkpoint.", self._check_interval_secs) + return False + self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) if latest_path is None: @@ -740,6 +784,10 @@ class ValidationMonitor(EveryN): class CaptureVariable(EveryN): """Captures a variable's values into a collection. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This monitor is useful for unit testing. You should exercise caution when using this monitor in production, since it never discards values. @@ -778,6 +826,7 @@ class CaptureVariable(EveryN): self._var_values[step] = _extract_output(outputs, self._var_name) +@deprecation.deprecated(None, "Use tf.train.MonitoredTrainingSession.") def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, @@ -812,6 +861,10 @@ def get_default_monitors(loss_op=None, class GraphDump(BaseMonitor): """Dumps almost all tensors in the graph at every step. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note, this is very expensive, prefer `PrintTensor` in production. """ @@ -901,7 +954,12 @@ class GraphDump(BaseMonitor): class ExportMonitor(EveryN): - """Monitor that exports Estimator every N steps.""" + """Monitor that exports Estimator every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ @deprecation.deprecated("2017-03-25", "ExportMonitor is deprecated. Please pass an " @@ -1024,7 +1082,12 @@ class ExportMonitor(EveryN): class CheckpointSaver(BaseMonitor): - """Saves checkpoints every N steps or N seconds.""" + """Saves checkpoints every N steps or N seconds. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, checkpoint_dir, @@ -1109,7 +1172,12 @@ class CheckpointSaver(BaseMonitor): class StepCounter(EveryN): - """Steps per second monitor.""" + """Steps per second monitor. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None): super(StepCounter, self).__init__(every_n_steps=every_n_steps) @@ -1149,6 +1217,10 @@ class NanLossDuringTrainingError(RuntimeError): class NanLoss(EveryN): """NaN Loss monitor. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Monitors loss and stops training if loss is NaN. Can either fail with exception or just stop training. """ diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index b2b24776c60183113a5f936dd276ff312d6d0079..5c34d0ddb01f3bcdc407e6926e7c5b73be1863b4 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -385,7 +385,11 @@ class MonitorsTest(test.TestCase): estimator.evaluate.return_value = validation_outputs monitor = learn.monitors.ValidationMonitor( - x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=2) + x=constant_op.constant(2.0), + every_n_steps=0, + early_stopping_rounds=2, + check_interval_secs=None) + self._assert_validation_monitor(monitor) monitor.set_estimator(estimator) with ops.Graph().as_default() as g, self.test_session(g): diff --git a/tensorflow/contrib/learn/python/learn/ops/__init__.py b/tensorflow/contrib/learn/python/learn/ops/__init__.py index 33962e34cc685ce2c830a7bbfd1b5c626bcd8b31..efb1f47cf5bb2dcd0fb37b7b85cd8f170d56e4d1 100644 --- a/tensorflow/contrib/learn/python/learn/ops/__init__.py +++ b/tensorflow/contrib/learn/python/learn/ops/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Various TensorFlow Ops.""" +"""Various TensorFlow Ops (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py index fa3b7323e343371e986b763d30a8a44620894549..8f9811cf251ae0af1e0055a56e1358c2771b1367 100644 --- a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops to work with embeddings. +"""TensorFlow Ops to work with embeddings (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Note: categorical variables are handled via embeddings in many cases. For example, in case of words. @@ -57,7 +61,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'): ids = ops.convert_to_tensor(ids) shape = array_ops_.shape(ids) ids_flat = array_ops_.reshape( - ids, math_ops.reduce_prod(shape, keep_dims=True)) + ids, math_ops.reduce_prod(shape, keepdims=True)) embeds_flat = nn.embedding_lookup(params, ids_flat, name) embed_shape = array_ops_.concat([shape, [-1]], 0) embeds = array_ops_.reshape(embeds_flat, embed_shape) diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py index b040ab3bb6c516158589a8e30d56fff1f7728951..92976d1539c7ddc226b81f903beee82b798ec8db 100644 --- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for loss computation.""" +"""TensorFlow Ops for loss computation (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py index 45727faab4362abeab18f77861353eb53976023a..aa37cb4a76e2a6157bf077d327248353bd516472 100644 --- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for Sequence to Sequence models.""" +"""TensorFlow Ops for Sequence to Sequence models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,8 +31,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): """Returns predictions and loss for sequence of predictions. @@ -57,6 +64,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): return array_ops.stack(predictions, axis=1), loss +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): """Processes inputs for Sequence to Sequence models. @@ -87,6 +95,7 @@ def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): return in_x, in_y, out_y +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): """RNN Decoder that creates training and sampling sub-graphs. @@ -123,6 +132,7 @@ def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): return outputs, states, sampling_outputs, sampling_states +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_seq2seq(encoder_inputs, decoder_inputs, encoder_cell, diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py index 7bcc177d4ea0ab57f092d68888a72de2b2fd5edc..e8c6e1acf80f0791421bee59aff30e67bccb44b2 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Preprocessing tools useful for building models.""" +"""Preprocessing tools useful for building models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py index 154739d497ec1029026eaca1e93b37cd225f1050..faba3b2025e8abb51d1989c3fafbd5e711d6559b 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements preprocessing transformers for categorical variables.""" +"""Implements preprocessing transformers for categorical variables (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,8 @@ from __future__ import print_function import math import numpy as np +from tensorflow.python.util.deprecation import deprecated + # pylint: disable=g-bad-import-order from . import categorical_vocabulary from ..learn_io.data_feeder import setup_processor_data_feeder @@ -31,10 +38,16 @@ from ..learn_io.data_feeder import setup_processor_data_feeder class CategoricalProcessor(object): """Maps documents to sequences of word ids. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + As a common convention, Nan values are handled as unknown tokens. Both float('nan') and np.nan are accepted. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data for sequence ' + 'processing.') def __init__(self, min_frequency=0, share=False, vocabularies=None): """Initializes a CategoricalProcessor instance. diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py index 5709955c49fba50ca4a299a443a2902bbd9c6b23..3ac370a6ab4423846e810900514445ad5269b680 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""Categorical vocabulary classes to map categories to indexes. +"""Categorical vocabulary classes to map categories to indexes (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Can be used for categorical variables, sparse variables and words. """ @@ -25,14 +29,21 @@ from __future__ import print_function import collections import six +from tensorflow.python.util.deprecation import deprecated + class CategoricalVocabulary(object): """Categorical variables vocabulary class. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Accumulates and provides mapping from classes to indexes. Can be easily used for words. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, unknown_token="", support_reverse=True): self._unknown_token = unknown_token self._mapping = {unknown_token: 0} diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/text.py b/tensorflow/contrib/learn/python/learn/preprocessing/text.py index 3af2074c2a46f0258c04111fff0235ba8309625e..f2b6776be7789a9433bfe41eb9354b74347059ec 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/text.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/text.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements a number of text preprocessing utilities.""" +"""Implements a number of text preprocessing utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -24,6 +29,7 @@ import numpy as np import six from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated from .categorical_vocabulary import CategoricalVocabulary # pylint: disable=g-bad-import-order @@ -38,6 +44,7 @@ TOKENIZER_RE = re.compile(r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+", re.UNICODE) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def tokenizer(iterator): """Tokenizer generator. @@ -51,9 +58,16 @@ def tokenizer(iterator): yield TOKENIZER_RE.findall(value) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') class ByteProcessor(object): - """Maps documents into sequence of ids for bytes.""" + """Maps documents into sequence of ids for bytes. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length): self.max_document_length = max_document_length @@ -108,8 +122,14 @@ class ByteProcessor(object): class VocabularyProcessor(object): - """Maps documents to sequences of word ids.""" + """Maps documents to sequences of word ids. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length, min_frequency=0, diff --git a/tensorflow/contrib/learn/python/learn/session_run_hook.py b/tensorflow/contrib/learn/python/learn/session_run_hook.py index a8ba2be97206f2b974d256ad2c62c21a4e3e55d8..87edc9b720bdb3edcd5f2dcd1662d14da53c51cf 100644 --- a/tensorflow/contrib/learn/python/learn/session_run_hook.py +++ b/tensorflow/contrib/learn/python/learn/session_run_hook.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This file is deprecated. Use tensorflow.python.training.session_run_hook.""" +"""This file is deprecated. Use `tensorflow.python.training.session_run_hook`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py index 919d415c302b8ec17202aad34ff0bee69bfee2c7..d663cf5fb79c428b0e70d66b0f1305f0559a05c9 100644 --- a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py +++ b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wrapper for a Session-like object that handles threads and recovery. +"""Wrapper for a Session-like object that handles threads and recovery (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. Based on an original design of Illia Polosukhin. """ diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 429b6040be21d8cbe1f2bba58090366552fdfbe7..a1a3f20dcd8cb5ff7baa559ac41d5e5c40780511 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Trainable` interface.""" +"""`Trainable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,8 @@ import abc class Trainable(object): """Interface for objects that are trainable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py index 48978d0ac34cec2b18e6794dcf3b260bc3b683c4..66d8dc6fd43b383919a16515bc96be492a253bf6 100644 --- a/tensorflow/contrib/learn/python/learn/utils/__init__.py +++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Learn Utils.""" +"""TensorFlow Learn Utils (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index cb34cb1d26b6812c7f3f39e9f965615de5a8ef07..3eacac7a3d3dcff4d39025fdee88e16e385b1b84 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -13,14 +13,18 @@ # limitations under the License. # ============================================================================== -"""Export utilities.""" +"""Export utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -32,6 +36,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver +from tensorflow.python.training import training_util @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py index 226915987a4934626066b12810f579ae675107b2..916aecbea88b10bbef316ffb89d4c4d89667cb29 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -r"""System for specifying garbage collection (GC) of path based data. +r"""System for specifying garbage collection (GC) of path based data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This framework allows for GC of data specified by path names, for example files on disk. gc.Path objects each represent a single item stored at a path and may @@ -73,10 +77,12 @@ import os from tensorflow.python.platform import gfile from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated Path = collections.namedtuple('Path', 'path export_version') +@deprecated(None, 'Please implement your own file management or use Saver.') def largest_export_versions(n): """Creates a filter that keeps the largest n export versions. @@ -97,6 +103,7 @@ def largest_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def one_of_every_n_export_versions(n): """Creates a filter that keeps one of every n export versions. @@ -128,6 +135,7 @@ def one_of_every_n_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def mod_export_version(n): """Creates a filter that keeps every export that is a multiple of n. @@ -146,6 +154,7 @@ def mod_export_version(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def union(lf, rf): """Creates a filter that keeps the union of two filters. @@ -163,6 +172,7 @@ def union(lf, rf): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def negation(f): """Negate a filter. @@ -179,6 +189,7 @@ def negation(f): return keep +@deprecated(None, 'Please implement your own file name management.') def get_paths(base_dir, parser): """Gets a list of Paths in a given directory. diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py index b2521933e524e7ec24d73d4b5171f33e507dd88c..b92eb9fea8b7ccea56c781df74dcfa1cc5508e48 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for creating input_fns. +"""Utilities for creating input_fns (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Contents of this file are moved to tensorflow/python/estimator/export.py. InputFnOps is renamed to ServingInputReceiver. @@ -32,13 +36,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.util.deprecation import deprecated class InputFnOps(collections.namedtuple('InputFnOps', ['features', 'labels', 'default_inputs'])): - """A return type for an input_fn. + """A return type for an input_fn (deprecated). + + THIS CLASS IS DEPRECATED. Please use tf.estimator.export.ServingInputReceiver + instead. This return type is currently only supported for serving input_fn. Training and eval input_fn should return a `(features, labels)` tuple. @@ -56,6 +64,8 @@ class InputFnOps(collections.namedtuple('InputFnOps', """ +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_parsing_serving_input_receiver_fn.') def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): """Build an input_fn appropriate for serving, expecting fed tf.Examples. @@ -84,6 +94,8 @@ def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): return input_fn +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_raw_serving_input_receiver_fn.') def build_default_serving_input_fn(features, default_batch_size=None): """Build an input_fn appropriate for serving, expecting feature Tensors. diff --git a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py index 6a63fb545a56e6040b0b0c3bbb6a17cd96925895..6dbaa15f8391b0044be8e30ca191753beb88db93 100644 --- a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py +++ b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A simple script for inspect checkpoint files.""" +"""A simple script for inspect checkpoint files (deprecated).""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 1593380007b2799fb1d17e92408ab19a7b47fe1e..c7cdb4131215c388412407a008113de13bdd0934 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities supporting export to SavedModel. +"""Utilities supporting export to SavedModel (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Some contents of this file are moved to tensorflow/python/estimator/export.py: @@ -52,8 +56,9 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator from tensorflow.python.training import saver - from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated + # A key for use in the input_alternatives dict indicating the default input. # This is the input that will be expected when a serving request does not @@ -77,6 +82,7 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_standardized_signature_def(input_tensors, output_tensors, problem_type): """Build a SignatureDef using problem type and input and output Tensors. @@ -156,6 +162,7 @@ def _is_regression_problem(problem_type, input_tensors, output_tensors): len(input_tensors) == 1 and len(output_tensors) == 1) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_input_alternatives(input_ops): """Obtain all input alternatives using the input_fn output and heuristics.""" input_alternatives = {} @@ -181,6 +188,7 @@ def get_input_alternatives(input_ops): return input_alternatives, features +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): """Obtain all output alternatives using the model_fn output and heuristics. @@ -246,6 +254,7 @@ def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): sorted(output_alternatives.keys()))) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_all_signature_defs(input_alternatives, output_alternatives, actual_default_output_alternative_key): """Build `SignatureDef`s from all pairs of input and output alternatives.""" @@ -279,6 +288,7 @@ def build_all_signature_defs(input_alternatives, output_alternatives, MAX_DIRECTORY_CREATION_ATTEMPTS = 10 +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_timestamped_export_dir(export_dir_base): """Builds a path to a new subdirectory within the base directory. @@ -317,6 +327,7 @@ def get_timestamped_export_dir(export_dir_base): '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_temp_export_dir(timestamped_export_dir): """Builds a directory name based on the argument but starting with 'temp-'. @@ -344,6 +355,7 @@ def _export_version_parser(path): return path._replace(export_version=int(filename)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_most_recent_export(export_dir_base): """Locate the most recent SavedModel export in a directory of many exports. @@ -363,6 +375,7 @@ def get_most_recent_export(export_dir_base): return next(iter(results or []), None) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def garbage_collect_exports(export_dir_base, exports_to_keep): """Deletes older exports, retaining only a given number of the most recent. @@ -387,6 +400,7 @@ def garbage_collect_exports(export_dir_base, exports_to_keep): logging.warn('Can not delete %s recursively: %s', p.path, e) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, @@ -400,7 +414,7 @@ def make_export_strategy(serving_input_fn, `InputFnOps`. default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from ${tf.estimator.Estimator} + Must be `None` if the estimator inherits from @{tf.estimator.Estimator} or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -438,7 +452,7 @@ def make_export_strategy(serving_input_fn, The string path to the exported directory. Raises: - ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + ValueError: If `estimator` is a @{tf.estimator.Estimator} instance and `default_output_alternative_key` was specified. """ if isinstance(estimator, core_estimator.Estimator): @@ -469,6 +483,8 @@ def make_export_strategy(serving_input_fn, return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs) +@deprecated(None, + 'Use tf.estimator.export.build_parsing_serving_input_receiver_fn') def make_parsing_export_strategy(feature_columns, default_output_alternative_key=None, assets_extra=None, @@ -487,7 +503,7 @@ def make_parsing_export_strategy(feature_columns, that must be provided at serving time (excluding labels!). default_output_alternative_key: the name of the head to serve when an incoming serving request does not explicitly request a specific head. - Must be `None` if the estimator inherits from ${tf.estimator.Estimator} + Must be `None` if the estimator inherits from @{tf.estimator.Estimator} or for single-headed models. assets_extra: A dict specifying how to populate the assets.extra directory within the exported SavedModel. Each key should give the destination @@ -555,8 +571,14 @@ def _default_compare_fn(curr_best_eval_result, cand_eval_result): class BestModelSelector(object): - """A helper that keeps track of export selection candidates.""" + """A helper that keeps track of export selection candidates. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def __init__(self, event_file_pattern=None, compare_fn=None): """Constructor of this class. @@ -622,6 +644,7 @@ class BestModelSelector(object): return best_eval_result +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_best_model_export_strategy( serving_input_fn, exports_to_keep=1, @@ -707,6 +730,7 @@ def make_best_model_export_strategy( # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def extend_export_strategy(base_export_strategy, post_export_fn, post_export_name=None): @@ -741,7 +765,7 @@ def extend_export_strategy(base_export_strategy, The string path to the SavedModel indicated by post_export_fn. Raises: - ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + ValueError: If `estimator` is a @{tf.estimator.Estimator} instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. RuntimeError: If unable to create temporary or final export directory. diff --git a/tensorflow/contrib/legacy_seq2seq/BUILD b/tensorflow/contrib/legacy_seq2seq/BUILD index 1fa55132b1fc0cd3367ca2eb331b6870edc30c3b..8c2c4fd29c0502d4199f27a65e4827b2db973c3d 100644 --- a/tensorflow/contrib/legacy_seq2seq/BUILD +++ b/tensorflow/contrib/legacy_seq2seq/BUILD @@ -60,15 +60,3 @@ cuda_py_tests( ], tags = ["noasan"], # times out b/63678675 ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD index df96402a4ffd51840f77d58d8066487030362340..4dccb9be7cd2e603edcf10c020cc0ee1675f518a 100644 --- a/tensorflow/contrib/libsvm/BUILD +++ b/tensorflow/contrib/libsvm/BUILD @@ -88,15 +88,3 @@ tf_py_test( "//tensorflow/python:platform_test", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 208e7bc69be76680868c766bc99429eea5870c80..2c5fa7af89bccefb84c7b0e0c0a628e3ce737706 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -43,14 +43,40 @@ cuda_py_test( ], ) -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], +cuda_py_test( + name = "linear_operator_block_diag_test", + size = "medium", + srcs = ["python/kernel_tests/linear_operator_block_diag_test.py"], + additional_deps = [ + ":linalg_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], + shard_count = 5, + tags = ["noasan"], +) + +cuda_py_test( + name = "linear_operator_kronecker_test", + size = "medium", + srcs = ["python/kernel_tests/linear_operator_kronecker_test.py"], + additional_deps = [ + ":linalg_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], + shard_count = 8, + tags = ["noasan"], ) diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 4720692c3384ba1bede1f486c1b1e0e69d10a63a..38bd66b13f79ee918acbdc2d33a6367e4e34d4b1 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -17,10 +17,12 @@ See the @{$python/contrib.linalg} guide. @@LinearOperator +@@LinearOperatorBlockDiag @@LinearOperatorDiag @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@LinearOperatorFullMatrix +@@LinearOperatorKronecker @@LinearOperatorLowerTriangular @@LinearOperatorLowRankUpdate @@LinearOperatorComposition @@ -34,6 +36,8 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * +from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import * +from tensorflow.contrib.linalg.python.ops.linear_operator_kronecker import * from tensorflow.python.ops.linalg.linear_operator import * from tensorflow.python.ops.linalg.linear_operator_composition import * from tensorflow.python.ops.linalg.linear_operator_diag import * @@ -45,4 +49,5 @@ from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.python.util.all_util import remove_undocumented + remove_undocumented(__name__) diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7407ede11409a47f4d9db96ad5b5d801ef1625d --- /dev/null +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.linalg.python.ops import linear_operator_block_diag as block_diag +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib +from tensorflow.python.ops.linalg import linear_operator_test_util +from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.platform import test + +linalg = linalg_lib +random_seed.set_random_seed(23) +rng = np.random.RandomState(0) + + +def _block_diag_dense(expected_shape, blocks): + """Convert a list of blocks, into a dense block diagonal matrix.""" + rows = [] + num_cols = 0 + for block in blocks: + # Get the batch shape for the block. + batch_row_shape = array_ops.shape(block)[:-1] + + zeros_to_pad_before_shape = array_ops.concat( + [batch_row_shape, [num_cols]], axis=-1) + zeros_to_pad_before = array_ops.zeros( + shape=zeros_to_pad_before_shape, dtype=block.dtype) + num_cols += array_ops.shape(block)[-1] + zeros_to_pad_after_shape = array_ops.concat( + [batch_row_shape, [expected_shape[-2] - num_cols]], axis=-1) + zeros_to_pad_after = array_ops.zeros( + zeros_to_pad_after_shape, dtype=block.dtype) + + rows.append(array_ops.concat( + [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1)) + + return array_ops.concat(rows, axis=-2) + + +class SquareLinearOperatorBlockDiagTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + def setUp(self): + # Increase from 1e-6 to 1e-4 + self._atol[dtypes.float32] = 1e-4 + self._atol[dtypes.complex64] = 1e-4 + self._rtol[dtypes.float32] = 1e-4 + self._rtol[dtypes.complex64] = 1e-4 + + @property + def _operator_build_infos(self): + build_info = linear_operator_test_util.OperatorBuildInfo + return [ + build_info((0, 0)), + build_info((1, 1)), + build_info((1, 3, 3)), + build_info((5, 5), blocks=[(2, 2), (3, 3)]), + build_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]), + build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]), + ] + + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) + expected_blocks = ( + build_info.__dict__["blocks"] if "blocks" in build_info.__dict__ + else [shape]) + matrices = [ + linear_operator_test_util.random_positive_definite_matrix( + block_shape, dtype, force_well_conditioned=True) + for block_shape in expected_blocks + ] + + if use_placeholder: + matrices_ph = [ + array_ops.placeholder(dtype=dtype) for _ in expected_blocks + ] + # Evaluate here because (i) you cannot feed a tensor, and (ii) + # values are random and we want the same value used for both mat and + # feed_dict. + matrices = self.evaluate(matrices) + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + m_ph, is_square=True) for m_ph in matrices_ph], + is_square=True) + feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} + else: + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix( + m, is_square=True) for m in matrices]) + feed_dict = None + # Should be auto-set. + self.assertTrue(operator.is_square) + + # Broadcast the shapes. + expected_shape = list(build_info.shape) + + matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices) + + block_diag_dense = _block_diag_dense(expected_shape, matrices) + + if not use_placeholder: + block_diag_dense.set_shape( + expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]]) + + return operator, block_diag_dense, feed_dict + + def test_is_x_flags(self): + # Matrix with two positive eigenvalues, 1, and 1. + # The matrix values do not effect auto-setting of the flags. + matrix = [[1., 0.], [1., 1.]] + operator = block_diag.LinearOperatorBlockDiag( + [linalg.LinearOperatorFullMatrix(matrix)], + is_positive_definite=True, + is_non_singular=True, + is_self_adjoint=False) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertFalse(operator.is_self_adjoint) + + def test_is_non_singular_auto_set(self): + # Matrix with two positive eigenvalues, 11 and 8. + # The matrix values do not effect auto-setting of the flags. + matrix = [[11., 0.], [1., 8.]] + operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True) + operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True) + + operator = block_diag.LinearOperatorBlockDiag( + [operator_1, operator_2], + is_positive_definite=False, # No reason it HAS to be False... + is_non_singular=None) + self.assertFalse(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + + with self.assertRaisesRegexp(ValueError, "always non-singular"): + block_diag.LinearOperatorBlockDiag( + [operator_1, operator_2], is_non_singular=False) + + def test_name(self): + matrix = [[11., 0.], [1., 8.]] + operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left") + operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right") + + operator = block_diag.LinearOperatorBlockDiag([operator_1, operator_2]) + + self.assertEqual("left_ds_right", operator.name) + + def test_different_dtypes_raises(self): + operators = [ + linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)), + linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32)) + ] + with self.assertRaisesRegexp(TypeError, "same dtype"): + block_diag.LinearOperatorBlockDiag(operators) + + def test_non_square_operator_raises(self): + operators = [ + linalg.LinearOperatorFullMatrix(rng.rand(3, 4), is_square=False), + linalg.LinearOperatorFullMatrix(rng.rand(3, 3)) + ] + with self.assertRaisesRegexp(ValueError, "square matrices"): + block_diag.LinearOperatorBlockDiag(operators) + + def test_empty_operators_raises(self): + with self.assertRaisesRegexp(ValueError, "non-empty"): + block_diag.LinearOperatorBlockDiag([]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6574da22a188c7aa25ad8426522d0b446af8f5f3 --- /dev/null +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_kronecker_test.py @@ -0,0 +1,194 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.linalg.python.ops import linear_operator_kronecker as kronecker +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib +from tensorflow.python.ops.linalg import linear_operator_test_util +from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.platform import test + +linalg = linalg_lib +random_seed.set_random_seed(23) +rng = np.random.RandomState(0) + + +def _kronecker_dense(factors): + """Convert a list of factors, into a dense Kronecker product.""" + product = factors[0] + for factor in factors[1:]: + product = product[..., array_ops.newaxis, :, array_ops.newaxis] + factor_to_mul = factor[..., array_ops.newaxis, :, array_ops.newaxis, :] + product *= factor_to_mul + product = array_ops.reshape( + product, + shape=array_ops.concat( + [array_ops.shape(product)[:-4], + [array_ops.shape(product)[-4] * array_ops.shape(product)[-3], + array_ops.shape(product)[-2] * array_ops.shape(product)[-1]] + ], axis=0)) + + return product + + +class KroneckerDenseTest(test.TestCase): + + def testKroneckerDenseMatrix(self): + x = ops.convert_to_tensor([[2., 3.], [1., 2.]], dtype=dtypes.float32) + y = ops.convert_to_tensor([[1., 2.], [5., -1.]], dtype=dtypes.float32) + # From explicitly writing out the kronecker product of x and y. + z = ops.convert_to_tensor([ + [2., 4., 3., 6.], + [10., -2., 15., -3.], + [1., 2., 2., 4.], + [5., -1., 10., -2.]], dtype=dtypes.float32) + # From explicitly writing out the kronecker product of y and x. + w = ops.convert_to_tensor([ + [2., 3., 4., 6.], + [1., 2., 2., 4.], + [10., 15., -2., -3.], + [5., 10., -1., -2.]], dtype=dtypes.float32) + + with self.test_session(): + self.assertAllClose(_kronecker_dense([x, y]).eval(), z.eval()) + self.assertAllClose(_kronecker_dense([y, x]).eval(), w.eval()) + + +class SquareLinearOperatorKroneckerTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + def setUp(self): + # Increase from 1e-6 to 1e-4 + self._atol[dtypes.float32] = 1e-4 + self._atol[dtypes.complex64] = 1e-4 + self._rtol[dtypes.float32] = 1e-4 + self._rtol[dtypes.complex64] = 1e-4 + + @property + def _operator_build_infos(self): + build_info = linear_operator_test_util.OperatorBuildInfo + return [ + build_info((1, 1), factors=[(1, 1), (1, 1)]), + build_info((8, 8), factors=[(2, 2), (2, 2), (2, 2)]), + build_info((12, 12), factors=[(2, 2), (3, 3), (2, 2)]), + build_info((1, 3, 3), factors=[(1, 1), (1, 3, 3)]), + build_info((3, 6, 6), factors=[(3, 1, 1), (1, 2, 2), (1, 3, 3)]), + ] + + def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder): + shape = list(build_info.shape) + expected_factors = build_info.__dict__["factors"] + matrices = [ + linear_operator_test_util.random_positive_definite_matrix( + block_shape, dtype, force_well_conditioned=True) + for block_shape in expected_factors + ] + + if use_placeholder: + matrices_ph = [ + array_ops.placeholder(dtype=dtype) for _ in expected_factors + ] + # Evaluate here because (i) you cannot feed a tensor, and (ii) + # values are random and we want the same value used for both mat and + # feed_dict. + matrices = self.evaluate(matrices) + operator = kronecker.LinearOperatorKronecker( + [linalg.LinearOperatorFullMatrix( + m_ph, is_square=True) for m_ph in matrices_ph], + is_square=True) + feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)} + else: + operator = kronecker.LinearOperatorKronecker( + [linalg.LinearOperatorFullMatrix( + m, is_square=True) for m in matrices]) + feed_dict = None + # Should be auto-set. + self.assertTrue(operator.is_square) + + matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices) + + kronecker_dense = _kronecker_dense(matrices) + + if not use_placeholder: + kronecker_dense.set_shape(shape) + + return operator, kronecker_dense, feed_dict + + def test_is_x_flags(self): + # Matrix with two positive eigenvalues, 1, and 1. + # The matrix values do not effect auto-setting of the flags. + matrix = [[1., 0.], [1., 1.]] + operator = kronecker.LinearOperatorKronecker( + [linalg.LinearOperatorFullMatrix(matrix), + linalg.LinearOperatorFullMatrix(matrix)], + is_positive_definite=True, + is_non_singular=True, + is_self_adjoint=False) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertFalse(operator.is_self_adjoint) + + def test_is_non_singular_auto_set(self): + # Matrix with two positive eigenvalues, 11 and 8. + # The matrix values do not effect auto-setting of the flags. + matrix = [[11., 0.], [1., 8.]] + operator_1 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True) + operator_2 = linalg.LinearOperatorFullMatrix(matrix, is_non_singular=True) + + operator = kronecker.LinearOperatorKronecker( + [operator_1, operator_2], + is_positive_definite=False, # No reason it HAS to be False... + is_non_singular=None) + self.assertFalse(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + + with self.assertRaisesRegexp(ValueError, "always non-singular"): + kronecker.LinearOperatorKronecker( + [operator_1, operator_2], is_non_singular=False) + + def test_name(self): + matrix = [[11., 0.], [1., 8.]] + operator_1 = linalg.LinearOperatorFullMatrix(matrix, name="left") + operator_2 = linalg.LinearOperatorFullMatrix(matrix, name="right") + + operator = kronecker.LinearOperatorKronecker([operator_1, operator_2]) + + self.assertEqual("left_x_right", operator.name) + + def test_different_dtypes_raises(self): + operators = [ + linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)), + linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3).astype(np.float32)) + ] + with self.assertRaisesRegexp(TypeError, "same dtype"): + kronecker.LinearOperatorKronecker(operators) + + def test_empty_or_one_operators_raises(self): + with self.assertRaisesRegexp(ValueError, ">=1 operators"): + kronecker.LinearOperatorKronecker([]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py new file mode 100644 index 0000000000000000000000000000000000000000..9d3af66c92b59dd030d4b2a829ab733eec6cf0c1 --- /dev/null +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_block_diag.py @@ -0,0 +1,370 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Create a Block Diagonal operator from one or more `LinearOperators`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_util + + +class LinearOperatorBlockDiag(linear_operator.LinearOperator): + """Combines one or more `LinearOperators` in to a Block Diagonal matrix. + + This operator combines one or more linear operators `[op1,...,opJ]`, + building a new `LinearOperator`, whose underlying matrix representation is + square and has each operator `opi` on the main diagonal, and zero's elsewhere. + + #### Shape compatibility + + If `opj` acts like a [batch] square matrix `Aj`, then `op_combined` acts like + the [batch] square matrix formed by having each matrix `Aj` on the main + diagonal. + + + Each `opj` is required to represent a square matrix, and hence will have + shape `batch_shape_j + [M_j, M_j]`. + + If `opj` has shape `batch_shape_j + [M_j, M_j]`, then the combined operator + has shape `broadcast_batch_shape + [sum M_j, sum M_j]`, where + `broadcast_batch_shape` is the mutual broadcast of `batch_shape_j`, + `j = 1,...,J`, assuming the intermediate batch shapes broadcast. + Even if the combined shape is well defined, the combined operator's + methods may fail due to lack of broadcasting ability in the defining + operators' methods. + + ```python + # Create a 4 x 4 linear operator combined of two 2 x 2 operators. + operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]]) + operator_2 = LinearOperatorFullMatrix([[1., 0.], [0., 1.]]) + operator = LinearOperatorBlockDiag([operator_1, operator_2]) + + operator.to_dense() + ==> [[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]] + + operator.shape + ==> [4, 4] + + operator.log_abs_determinant() + ==> scalar Tensor + + x1 = ... # Shape [2, 2] Tensor + x2 = ... # Shape [2, 2] Tensor + x = tf.concat([x1, x2], 0) # Shape [2, 4] Tensor + operator.matmul(x) + ==> tf.concat([operator_1.matmul(x1), operator_2.matmul(x2)]) + + # Create a [2, 3] batch of 4 x 4 linear operators. + matrix_44 = tf.random_normal(shape=[2, 3, 4, 4]) + operator_44 = LinearOperatorFullMatrix(matrix) + + # Create a [1, 3] batch of 5 x 5 linear operators. + matrix_55 = tf.random_normal(shape=[1, 3, 5, 5]) + operator_55 = LinearOperatorFullMatrix(matrix_55) + + # Combine to create a [2, 3] batch of 9 x 9 operators. + operator_99 = LinearOperatorBlockDiag([operator_44, operator_55]) + + # Create a shape [2, 3, 9] vector. + x = tf.random_normal(shape=[2, 3, 9]) + operator_99.matmul(x) + ==> Shape [2, 3, 9] Tensor + ``` + + #### Performance + + The performance of `LinearOperatorBlockDiag` on any operation is equal to + the sum of the individual operators' operations. + + + #### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint, positive_definite, square`. + These have the following meaning: + + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. + """ + + def __init__(self, + operators, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None, + is_square=True, + name=None): + r"""Initialize a `LinearOperatorBlockDiag`. + + `LinearOperatorBlockDiag` is initialized with a list of operators + `[op_1,...,op_J]`. + + Args: + operators: Iterable of `LinearOperator` objects, each with + the same `dtype` and composable shape. + is_non_singular: Expect that this operator is non-singular. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices + is_square: Expect that this operator acts like square [batch] matrices. + This is true by default, and will raise a `ValueError` otherwise. + name: A name for this `LinearOperator`. Default is the individual + operators names joined with `_o_`. + + Raises: + TypeError: If all operators do not have the same `dtype`. + ValueError: If `operators` is empty or are non-square. + """ + # Validate operators. + check_ops.assert_proper_iterable(operators) + operators = list(operators) + if not operators: + raise ValueError( + "Expected a non-empty list of operators. Found: %s" % operators) + self._operators = operators + + # Validate dtype. + dtype = operators[0].dtype + for operator in operators: + if operator.dtype != dtype: + name_type = (str((o.name, o.dtype)) for o in operators) + raise TypeError( + "Expected all operators to have the same dtype. Found %s" + % " ".join(name_type)) + + # Auto-set and check hints. + if all(operator.is_non_singular for operator in operators): + if is_non_singular is False: + raise ValueError( + "The direct sum of non-singular operators is always non-singular.") + is_non_singular = True + + if all(operator.is_self_adjoint for operator in operators): + if is_self_adjoint is False: + raise ValueError( + "The direct sum of self-adjoint operators is always self-adjoint.") + is_self_adjoint = True + + if all(operator.is_positive_definite for operator in operators): + if is_positive_definite is False: + raise ValueError( + "The direct sum of positive definite operators is always " + "positive definite.") + is_positive_definite = True + + if not (is_square and all(operator.is_square for operator in operators)): + raise ValueError( + "Can only represent a block diagonal of square matrices.") + + # Initialization. + graph_parents = [] + for operator in operators: + graph_parents.extend(operator.graph_parents) + + if name is None: + # Using ds to mean direct sum. + name = "_ds_".join(operator.name for operator in operators) + with ops.name_scope(name, values=graph_parents): + super(LinearOperatorBlockDiag, self).__init__( + dtype=dtype, + graph_parents=graph_parents, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=True, + name=name) + + @property + def operators(self): + return self._operators + + def _shape(self): + # Get final matrix shape. + domain_dimension = self.operators[0].domain_dimension + range_dimension = self.operators[0].range_dimension + for operator in self.operators[1:]: + domain_dimension += operator.domain_dimension + range_dimension += operator.range_dimension + + matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension]) + + # Get broadcast batch shape. + # broadcast_shape checks for compatibility. + batch_shape = self.operators[0].batch_shape + for operator in self.operators[1:]: + batch_shape = common_shapes.broadcast_shape( + batch_shape, operator.batch_shape) + + return batch_shape.concatenate(matrix_shape) + + def _shape_tensor(self): + # Avoid messy broadcasting if possible. + if self.shape.is_fully_defined(): + return ops.convert_to_tensor( + self.shape.as_list(), dtype=dtypes.int32, name="shape") + + domain_dimension = self.operators[0].domain_dimension_tensor() + range_dimension = self.operators[0].range_dimension_tensor() + for operator in self.operators[1:]: + domain_dimension += operator.domain_dimension_tensor() + range_dimension += operator.range_dimension_tensor() + + matrix_shape = array_ops.stack([domain_dimension, range_dimension]) + + # Dummy Tensor of zeros. Will never be materialized. + zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor()) + for operator in self.operators[1:]: + zeros += array_ops.zeros(shape=operator.batch_shape_tensor()) + batch_shape = array_ops.shape(zeros) + + return array_ops.concat((batch_shape, matrix_shape), 0) + + def _matmul(self, x, adjoint=False, adjoint_arg=False): + split_dim = -1 if adjoint_arg else -2 + # Split input by rows normally, and otherwise columns. + split_x = self._split_input_into_blocks(x, axis=split_dim) + + result_list = [] + for index, operator in enumerate(self.operators): + result_list += [operator.matmul( + split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] + result_list = linear_operator_util.broadcast_matrix_batch_dims( + result_list) + return array_ops.concat(result_list, axis=-2) + + def _determinant(self): + result = self.operators[0].determinant() + for operator in self.operators[1:]: + result *= operator.determinant() + return result + + def _log_abs_determinant(self): + result = self.operators[0].log_abs_determinant() + for operator in self.operators[1:]: + result += operator.log_abs_determinant() + return result + + def _solve(self, rhs, adjoint=False, adjoint_arg=False): + split_dim = -1 if adjoint_arg else -2 + # Split input by rows normally, and otherwise columns. + split_rhs = self._split_input_into_blocks(rhs, axis=split_dim) + + solution_list = [] + for index, operator in enumerate(self.operators): + solution_list += [operator.solve( + split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)] + + solution_list = linear_operator_util.broadcast_matrix_batch_dims( + solution_list) + return array_ops.concat(solution_list, axis=-2) + + def _diag_part(self): + diag_list = [] + for operator in self.operators: + # Extend the axis for broadcasting. + diag_list += [operator.diag_part()[..., array_ops.newaxis]] + diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list) + diagonal = array_ops.concat(diag_list, axis=-2) + return array_ops.squeeze(diagonal, axis=-1) + + def _trace(self): + result = self.operators[0].trace() + for operator in self.operators[1:]: + result += operator.trace() + return result + + def _to_dense(self): + num_cols = 0 + rows = [] + broadcasted_blocks = [operator.to_dense() for operator in self.operators] + broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims( + broadcasted_blocks) + for block in broadcasted_blocks: + batch_row_shape = array_ops.shape(block)[:-1] + + zeros_to_pad_before_shape = array_ops.concat( + [batch_row_shape, [num_cols]], axis=-1) + zeros_to_pad_before = array_ops.zeros( + shape=zeros_to_pad_before_shape, dtype=block.dtype) + num_cols += array_ops.shape(block)[-1] + zeros_to_pad_after_shape = array_ops.concat( + [batch_row_shape, + [self.domain_dimension_tensor() - num_cols]], axis=-1) + zeros_to_pad_after = array_ops.zeros( + shape=zeros_to_pad_after_shape, dtype=block.dtype) + + rows.append(array_ops.concat( + [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1)) + + mat = array_ops.concat(rows, axis=-2) + mat.set_shape(self.shape) + return mat + + def _assert_non_singular(self): + return control_flow_ops.group([ + operator.assert_non_singular() for operator in self.operators]) + + def _assert_self_adjoint(self): + return control_flow_ops.group([ + operator.assert_self_adjoint() for operator in self.operators]) + + def _assert_positive_definite(self): + return control_flow_ops.group([ + operator.assert_positive_definite() for operator in self.operators]) + + def _split_input_into_blocks(self, x, axis=-1): + """Split `x` into blocks matching `operators`'s `domain_dimension`. + + Specifically, if we have a block diagonal matrix, with block sizes + `[M_j, M_j] j = 1..J`, this method splits `x` on `axis` into `J` + tensors, whose shape at `axis` is `M_j`. + + Args: + x: `Tensor`. `x` is split into `J` tensors. + axis: Python `Integer` representing the axis to split `x` on. + + Returns: + A list of `Tensor`s. + """ + block_sizes = [] + if self.shape.is_fully_defined(): + for operator in self.operators: + block_sizes += [operator.domain_dimension.value] + else: + for operator in self.operators: + block_sizes += [operator.domain_dimension_tensor()] + + return array_ops.split(x, block_sizes, axis=axis) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py b/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py new file mode 100644 index 0000000000000000000000000000000000000000..79080d194f59b7ebce045ab3e3d262ca948d9391 --- /dev/null +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_kronecker.py @@ -0,0 +1,560 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Construct the Kronecker product of one or more `LinearOperators`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg_impl as linalg +from tensorflow.python.ops.linalg import linear_operator + + +def _vec(x): + """Stacks column of matrix to form a single column.""" + return array_ops.reshape( + array_ops.matrix_transpose(x), + array_ops.concat( + [array_ops.shape(x)[:-2], [-1]], axis=0)) + + +def _unvec_by(y, num_col): + """Unstack vector to form a matrix, with a specified amount of columns.""" + return array_ops.matrix_transpose( + array_ops.reshape( + y, + array_ops.concat( + [array_ops.shape(y)[:-1], [num_col, -1]], axis=0))) + + +def _rotate_last_dim(x, rotate_right=False): + """Rotate the last dimension either left or right.""" + ndims = array_ops.rank(x) + if rotate_right: + transpose_perm = array_ops.concat( + [[ndims - 1], math_ops.range(0, ndims - 1)], axis=0) + else: + transpose_perm = array_ops.concat( + [math_ops.range(1, ndims), [0]], axis=0) + return array_ops.transpose(x, transpose_perm) + + +class LinearOperatorKronecker(linear_operator.LinearOperator): + """Kronecker product between two `LinearOperators`. + + This operator composes one or more linear operators `[op1,...,opJ]`, + building a new `LinearOperator` representing the Kronecker product: + `op1 x op2 x .. opJ` (we omit parentheses as the Kronecker product is + associative). + + If `opj` has shape `batch_shape_j` + [M_j, N_j`, then the composed operator + will have shape equal to `broadcast_batch_shape + [prod M_j, prod N_j]`, + where the product is over all operators. + + ```python + # Create a 4 x 4 linear operator composed of two 2 x 2 operators. + operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]]) + operator_2 = LinearOperatorFullMatrix([[1., 0.], [2., 1.]]) + operator = LinearOperatorKronecker([operator_1, operator_2]) + + operator.to_dense() + ==> [[1., 2., 0., 0.], + [3., 4., 0., 0.], + [2., 4., 1., 2.], + [6., 8., 3., 4.]] + + operator.shape + ==> [4, 4] + + operator.log_abs_determinant() + ==> scalar Tensor + + x = ... Shape [4, 2] Tensor + operator.matmul(x) + ==> Shape [4, 2] Tensor + + # Create a [2, 3] batch of 4 x 5 linear operators. + matrix_45 = tf.random_normal(shape=[2, 3, 4, 5]) + operator_45 = LinearOperatorFullMatrix(matrix) + + # Create a [2, 3] batch of 5 x 6 linear operators. + matrix_56 = tf.random_normal(shape=[2, 3, 5, 6]) + operator_56 = LinearOperatorFullMatrix(matrix_56) + + # Compose to create a [2, 3] batch of 20 x 30 operators. + operator_large = LinearOperatorKronecker([operator_45, operator_56]) + + # Create a shape [2, 3, 20, 2] vector. + x = tf.random_normal(shape=[2, 3, 6, 2]) + operator_large.matmul(x) + ==> Shape [2, 3, 30, 2] Tensor + ``` + + #### Performance + + The performance of `LinearOperatorKronecker` on any operation is equal to + the sum of the individual operators' operations. + + #### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint, positive_definite, square`. + These have the following meaning: + + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. + """ + + def __init__(self, + operators, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None, + is_square=None, + name=None): + r"""Initialize a `LinearOperatorKronecker`. + + `LinearOperatorKronecker` is initialized with a list of operators + `[op_1,...,op_J]`. + + Args: + operators: Iterable of `LinearOperator` objects, each with + the same `dtype` and composable shape, representing the Kronecker + factors. + is_non_singular: Expect that this operator is non-singular. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. + is_positive_definite: Expect that this operator is positive definite, + meaning the quadratic form `x^H A x` has positive real part for all + nonzero `x`. Note that we do not require the operator to be + self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix\ + #Extension_for_non_symmetric_matrices + is_square: Expect that this operator acts like square [batch] matrices. + name: A name for this `LinearOperator`. Default is the individual + operators names joined with `_x_`. + + Raises: + TypeError: If all operators do not have the same `dtype`. + ValueError: If `operators` is empty. + """ + # Validate operators. + check_ops.assert_proper_iterable(operators) + operators = list(operators) + if not operators: + raise ValueError( + "Expected a list of >=1 operators. Found: %s" % operators) + self._operators = operators + + # Validate dtype. + dtype = operators[0].dtype + for operator in operators: + if operator.dtype != dtype: + name_type = (str((o.name, o.dtype)) for o in operators) + raise TypeError( + "Expected all operators to have the same dtype. Found %s" + % " ".join(name_type)) + + # Auto-set and check hints. + # A Kronecker product is invertible, if and only if all factors are + # invertible. + if all(operator.is_non_singular for operator in operators): + if is_non_singular is False: + raise ValueError( + "The Kronecker product of non-singular operators is always " + "non-singular.") + is_non_singular = True + + if all(operator.is_self_adjoint for operator in operators): + if is_self_adjoint is False: + raise ValueError( + "The Kronecker product of self-adjoint operators is always " + "self-adjoint.") + is_self_adjoint = True + + # The eigenvalues of a Kronecker product are equal to the products of eigen + # values of the corresponding factors. + if all(operator.is_positive_definite for operator in operators): + if is_positive_definite is False: + raise ValueError("The Kronecker product of positive-definite operators " + "is always positive-definite.") + is_positive_definite = True + + # Initialization. + graph_parents = [] + for operator in operators: + graph_parents.extend(operator.graph_parents) + + if name is None: + name = operators[0].name + for operator in operators[1:]: + name += "_x_" + operator.name + with ops.name_scope(name, values=graph_parents): + super(LinearOperatorKronecker, self).__init__( + dtype=dtype, + graph_parents=graph_parents, + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + is_square=is_square, + name=name) + + @property + def operators(self): + return self._operators + + def _shape(self): + # Get final matrix shape. + domain_dimension = self.operators[0].domain_dimension + for operator in self.operators[1:]: + domain_dimension *= operator.domain_dimension + + range_dimension = self.operators[0].range_dimension + for operator in self.operators[1:]: + range_dimension *= operator.range_dimension + + matrix_shape = tensor_shape.TensorShape([ + range_dimension, domain_dimension]) + + # Get broadcast batch shape. + # broadcast_shape checks for compatibility. + batch_shape = self.operators[0].batch_shape + for operator in self.operators[1:]: + batch_shape = common_shapes.broadcast_shape( + batch_shape, operator.batch_shape) + + return batch_shape.concatenate(matrix_shape) + + def _shape_tensor(self): + domain_dimension = self.operators[0].domain_dimension_tensor() + for operator in self.operators[1:]: + domain_dimension *= operator.domain_dimension_tensor() + + range_dimension = self.operators[0].range_dimension_tensor() + for operator in self.operators[1:]: + range_dimension *= operator.range_dimension_tensor() + + matrix_shape = [range_dimension, domain_dimension] + + # Get broadcast batch shape. + # broadcast_shape checks for compatibility. + batch_shape = self.operators[0].batch_shape_tensor() + for operator in self.operators[1:]: + batch_shape = array_ops.broadcast_dynamic_shape( + batch_shape, operator.batch_shape_tensor()) + + return array_ops.concat((batch_shape, matrix_shape), 0) + + def _matmul(self, x, adjoint=False, adjoint_arg=False): + # Here we heavily rely on Roth's column Lemma [1]: + # (A x B) * vec X = vec BXA^T, + # where vec stacks all the columns of the matrix under each other. In our + # case, x represents a batch of vec X (i.e. we think of x as a batch of + # column vectors, rather than a matrix). Each member of the batch can be + # reshaped to a matrix (hence we get a batch of matrices). + # We can iteratively apply this lemma by noting that if B is a Kronecker + # product, then we can apply the lemma again. + + # [1] W. E. Roth, "On direct product matrices," + # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468, + # 1934 + + # Efficiency + + # Naively doing the Kronecker product, by calculating the dense matrix and + # applying it will can take cubic time in the size of domain_dimension + # (assuming a square matrix). The other issue is that calculating the dense + # matrix can be prohibitively expensive, in that it can take a large amount + # of memory. + # + # This implementation avoids this memory blow up by only computing matmuls + # with the factors. In this way, we don't have to realize the dense matrix. + # In terms of complexity, if we have Kronecker Factors of size: + # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we + # have as input a [N, M] matrix, the naive approach would take O(N^2 M). + # With this approach (ignoring reshaping of tensors and transposes for now), + # the time complexity can be O(M * (\sum n_i) * N). There is also the + # benefit of batched multiplication (In this example, the batch size is + # roughly M * N) so this can be much faster. However, not factored in are + # the costs of the several transposing of tensors, which can affect cache + # behavior. + + # Below we document the shape manipulation for adjoint=False, + # adjoint_arg=False, but the general case of different adjoints is still + # handled. + + if adjoint_arg: + x = linalg.adjoint(x) + + # Always add a batch dimension to enable broadcasting to work. + batch_shape = array_ops.concat( + [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0) + x += array_ops.zeros(batch_shape, dtype=x.dtype.base_dtype) + + # x has shape [B, R, C], where B represent some number of batch dimensions, + # R represents the number of rows, and C represents the number of columns. + # In order to apply Roth's column lemma, we need to operate on a batch of + # column vectors, so we reshape into a batch of column vectors. We put it + # at the front to ensure that broadcasting between operators to the batch + # dimensions B still works. + output = _rotate_last_dim(x, rotate_right=True) + + # Also expand the shape to be [A, C, B, R]. The first dimension will be + # used to accumulate dimensions from each operator matmul. + output = output[array_ops.newaxis, ...] + + # In this loop, A is going to refer to the value of the accumulated + # dimension. A = 1 at the start, and will end up being self.range_dimension. + # V will refer to the last dimension. V = R at the start, and will end up + # being 1 in the end. + for operator in self.operators[:-1]: + # Reshape output from [A, C, B, V] to be + # [A, C, B, V / op.domain_dimension, op.domain_dimension] + if adjoint: + operator_dimension = operator.range_dimension_tensor() + else: + operator_dimension = operator.domain_dimension_tensor() + + output = _unvec_by(output, operator_dimension) + + # We are computing (XA^T) = (AX^T)^T. + # output has [A, C, B, V / op.domain_dimension, op.domain_dimension], + # which is being converted to: + # [A, C, B, V / op.domain_dimension, op.range_dimension] + output = array_ops.matrix_transpose(output) + output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False) + output = array_ops.matrix_transpose(output) + # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension] + output = _rotate_last_dim(output, rotate_right=False) + output = _vec(output) + output = _rotate_last_dim(output, rotate_right=True) + + # After the loop, we will have + # A = self.range_dimension / op[-1].range_dimension + # V = op[-1].domain_dimension + + # We convert that using matvec to get: + # [A, C, B, op[-1].range_dimension] + output = self.operators[-1].matvec(output, adjoint=adjoint) + # Rearrange shape to be [B1, ... Bn, self.range_dimension, C] + output = _rotate_last_dim(output, rotate_right=False) + output = _vec(output) + output = _rotate_last_dim(output, rotate_right=False) + + if x.shape.is_fully_defined(): + column_dim = x.shape[-1] + broadcast_batch_shape = common_shapes.broadcast_shape( + x.shape[:-2], self.batch_shape) + if adjoint: + matrix_dimensions = [self.domain_dimension, column_dim] + else: + matrix_dimensions = [self.range_dimension, column_dim] + + print("x: ", x) + print("bathc_shape:", self.batch_shape) + print("self.shape:", self.shape) + print("output: ", output) + output.set_shape(broadcast_batch_shape.concatenate( + matrix_dimensions)) + + return output + + def _determinant(self): + # Note that we have |X1 x X2| = |X1| ** n * |X2| ** m, where X1 is an m x m + # matrix, and X2 is an n x n matrix. We can iteratively apply this property + # to get the determinant of |X1 x X2 x X3 ...|. If T is the product of the + # domain dimension of all operators, then we have: + # |X1 x X2 x X3 ...| = + # |X1| ** (T / m) * |X2 x X3 ... | ** m = + # |X1| ** (T / m) * |X2| ** (m * (T / m) / n) * ... = + # |X1| ** (T / m) * |X2| ** (T / n) * | X3 x X4... | ** (m * n) + # And by doing induction we have product(|X_i| ** (T / dim(X_i))). + total = self.domain_dimension_tensor() + determinant = 1. + for operator in self.operators: + determinant *= operator.determinant() ** math_ops.cast( + total / operator.domain_dimension_tensor(), + dtype=operator.dtype) + return determinant + + def _log_abs_determinant(self): + # This will be sum((total / dim(x_i)) * log |X_i|) + total = self.domain_dimension_tensor() + log_abs_det = 0. + for operator in self.operators: + log_abs_det += operator.log_abs_determinant() * math_ops.cast( + total / operator.domain_dimension_tensor(), + dtype=operator.dtype) + return log_abs_det + + def _trace(self): + # tr(A x B) = tr(A) * tr(B) + trace = 1. + for operator in self.operators: + trace *= operator.trace() + return trace + + def _solve(self, rhs, adjoint=False, adjoint_arg=False): + # Here we follow the same use of Roth's column lemma as in `matmul`, with + # the key difference that we replace all `matmul` instances with `solve`. + # This follows from the property that inv(A x B) = inv(A) x inv(B). + + # Below we document the shape manipulation for adjoint=False, + # adjoint_arg=False, but the general case of different adjoints is still + # handled. + + if adjoint_arg: + rhs = linalg.adjoint(rhs) + + # Always add a batch dimension to enable broadcasting to work. + batch_shape = array_ops.concat( + [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0) + rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype.base_dtype) + + # rhs has shape [B, R, C], where B represent some number of batch + # dimensions, + # R represents the number of rows, and C represents the number of columns. + # In order to apply Roth's column lemma, we need to operate on a batch of + # column vectors, so we reshape into a batch of column vectors. We put it + # at the front to ensure that broadcasting between operators to the batch + # dimensions B still works. + output = _rotate_last_dim(rhs, rotate_right=True) + + # Also expand the shape to be [A, C, B, R]. The first dimension will be + # used to accumulate dimensions from each operator matmul. + output = output[array_ops.newaxis, ...] + + # In this loop, A is going to refer to the value of the accumulated + # dimension. A = 1 at the start, and will end up being self.range_dimension. + # V will refer to the last dimension. V = R at the start, and will end up + # being 1 in the end. + for operator in self.operators[:-1]: + # Reshape output from [A, C, B, V] to be + # [A, C, B, V / op.domain_dimension, op.domain_dimension] + if adjoint: + operator_dimension = operator.range_dimension_tensor() + else: + operator_dimension = operator.domain_dimension_tensor() + + output = _unvec_by(output, operator_dimension) + + # We are computing (XA^-1^T) = (A^-1 X^T)^T. + # output has [A, C, B, V / op.domain_dimension, op.domain_dimension], + # which is being converted to: + # [A, C, B, V / op.domain_dimension, op.range_dimension] + output = array_ops.matrix_transpose(output) + output = operator.solve(output, adjoint=adjoint, adjoint_arg=False) + output = array_ops.matrix_transpose(output) + # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension] + output = _rotate_last_dim(output, rotate_right=False) + output = _vec(output) + output = _rotate_last_dim(output, rotate_right=True) + + # After the loop, we will have + # A = self.range_dimension / op[-1].range_dimension + # V = op[-1].domain_dimension + + # We convert that using matvec to get: + # [A, C, B, op[-1].range_dimension] + output = self.operators[-1].solvevec(output, adjoint=adjoint) + # Rearrange shape to be [B1, ... Bn, self.range_dimension, C] + output = _rotate_last_dim(output, rotate_right=False) + output = _vec(output) + output = _rotate_last_dim(output, rotate_right=False) + + if rhs.shape.is_fully_defined(): + column_dim = rhs.shape[-1] + broadcast_batch_shape = common_shapes.broadcast_shape( + rhs.shape[:-2], self.batch_shape) + if adjoint: + matrix_dimensions = [self.domain_dimension, column_dim] + else: + matrix_dimensions = [self.range_dimension, column_dim] + + output.set_shape(broadcast_batch_shape.concatenate( + matrix_dimensions)) + + return output + + def _diag_part(self): + diag_part = self.operators[0].diag_part() + for operator in self.operators[1:]: + diag_part = diag_part[..., :, array_ops.newaxis] + op_diag_part = operator.diag_part()[..., array_ops.newaxis, :] + diag_part *= op_diag_part + diag_part = array_ops.reshape( + diag_part, + shape=array_ops.concat( + [array_ops.shape(diag_part)[:-2], [-1]], axis=0)) + if self.range_dimension > self.domain_dimension: + diag_dimension = self.domain_dimension + else: + diag_dimension = self.range_dimension + diag_part.set_shape( + self.batch_shape.concatenate(diag_dimension)) + return diag_part + + def _to_dense(self): + product = self.operators[0].to_dense() + for operator in self.operators[1:]: + # Product has shape [B, R1, 1, C1]. + product = product[ + ..., :, array_ops.newaxis, :, array_ops.newaxis] + # Operator has shape [B, 1, R2, 1, C2]. + op_to_mul = operator.to_dense()[ + ..., array_ops.newaxis, :, array_ops.newaxis, :] + # This is now [B, R1, R2, C1, C2]. + product *= op_to_mul + # Now merge together dimensions to get [B, R1 * R2, C1 * C2]. + product = array_ops.reshape( + product, + shape=array_ops.concat( + [array_ops.shape(product)[:-4], + [array_ops.shape(product)[-4] * array_ops.shape(product)[-3], + array_ops.shape(product)[-2] * array_ops.shape(product)[-1]] + ], axis=0)) + product.set_shape(self.shape) + return product + + def _assert_non_singular(self): + if all(operator.is_square for operator in self.operators): + asserts = [operator.assert_non_singular() for operator in self.operators] + return control_flow_ops.group(asserts) + else: + raise errors.InvalidArgumentError( + node_def=None, op=None, message="All Kronecker factors must be " + "square for the product to be invertible.") + + def _assert_self_adjoint(self): + if all(operator.is_square for operator in self.operators): + asserts = [operator.assert_self_adjoint() for operator in self.operators] + return control_flow_ops.group(asserts) + else: + raise errors.InvalidArgumentError( + node_def=None, op=None, message="All Kronecker factors must be " + "square for the product to be self adjoint.") diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index cea3627ed565f0de86d8d9bb6b45c4b19c5b5558..5b89c6cef9fa9fdef7c26ddee1efa03f3056d881 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -138,14 +138,3 @@ py_test( "//third_party/py/numpy", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index 70f777f08bd5b8157e601f19019075d3e7543811..6e6c812adcbaf7e90d7c10c05fdfc0e150829329 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random import threading from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel @@ -102,6 +103,35 @@ def make_example_dict(example_protos, example_weights): example_ids=['%d' % i for i in range(0, len(example_protos))]) +def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): + random.seed(1) + + sparse_features = [ + SparseFeatureColumn( + [i for i in range(num_examples) for _ in range(num_non_zero)], [ + i for _ in range(num_examples) + for i in random.sample(range(dim), num_non_zero) + ], + [num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)]) + ] + examples_dict = dict( + sparse_features=sparse_features, + dense_features=[], + example_weights=[random.random() for _ in range(num_examples)], + example_labels=[ + 1. if random.random() > 0.5 else 0. for _ in range(num_examples) + ], + example_ids=[str(i) for i in range(num_examples)]) + + weights = variables_lib.Variable( + array_ops.zeros([dim], dtype=dtypes.float32)) + variables_dict = dict( + sparse_features_weights=[weights], + dense_features_weights=[]) + + return examples_dict, variables_dict + + def make_variable_dict(max_age, max_gender): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. @@ -235,6 +265,60 @@ class SdcaWithLogisticLossTest(SdcaModelTest): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSparseRandom(self): + dim = 20 + num_examples = 1000 + # Number of non-zero features per example. + non_zeros = 10 + # Setup test data. + with self._single_threaded_test_session(): + examples, variables = make_random_examples_and_variables_dicts( + num_examples, dim, non_zeros) + options = dict( + symmetric_l2_regularization=.1, + symmetric_l1_regularization=0, + num_table_shards=1, + adaptive=False, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + train_op = lr.minimize() + for _ in range(4): + train_op.run() + lr.update_weights(train_op).run() + # Duality gap is 1.4e-5. + # It would be 0.01 without shuffling and 0.02 with adaptive sampling. + self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-3) + + def testSparseDuplicate(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0] * 5, + 'gender': [0] * 5 + }, 0), + make_example_proto({ + 'age': [1] * 5, + 'gender': [1] * 5 + }, 1), + ] + example_weights = [1.0, 1.0] + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + train_op = lr.minimize() + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'Duplicate'): + train_op.run() + def testDistributedSimple(self): # Setup test data example_protos = [ @@ -270,14 +354,14 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def Minimize(): + def minimize(): with self._single_threaded_test_session(): for _ in range(_MAX_ITERATIONS): - train_op.run() + train_op.run() # pylint: disable=cell-var-from-loop threads = [] for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=Minimize)) + threads.append(threading.Thread(target=minimize)) threads[-1].start() for t in threads: @@ -395,7 +479,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllClose([0, 1, 1, 1], predicted_labels.eval()) self.assertAllClose( - 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + 0.0, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) def testFractionalExampleLabel(self): # Setup test data with 1 positive, and 1 mostly-negative example. @@ -407,7 +491,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest): make_example_proto({ 'age': [1], 'gender': [1] - }, 1), + }, 0.9), ] example_weights = [1.0, 1.0] for num_shards in _SHARD_NUMBERS: diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 3f5fdc18bb8f47cceee8f81dd5ded02059344b8b..f980746a19fb8e0a02b9d023c127da7ab33e457f 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -168,6 +168,10 @@ class SdcaModel(object): # of workers return self._options.get('num_loss_partitions', 1) + def _adaptive(self): + # Perform adaptive sampling. + return self._options.get('adaptive', True) + def _num_table_shards(self): # Number of hash table shards. # Return 1 if not specified or if the value is 'None' @@ -344,7 +348,8 @@ class SdcaModel(object): l1=self._options['symmetric_l1_regularization'], l2=self._symmetric_l2_regularization(), num_loss_partitions=self._num_loss_partitions(), - num_inner_iterations=1) + num_inner_iterations=1, + adaptative=self._adaptive()) # pylint: enable=protected-access with ops.control_dependencies([esu]): diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 05794a42c5f2d0eece6adab36fb5610078cece31..d4e54c82f988e0adcd16aad29702ee9f8b16aea3 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -140,8 +140,8 @@ def sdca_model_fn(features, labels, mode, params, config=None): parent_scope = "linear" - with variable_scope.variable_op_scope(features.values(), - parent_scope) as scope: + with variable_scope.variable_scope( + values=features.values(), name_or_scope=parent_scope) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index 79a5928a21cb9a2633b2aac178f185ba333790d6..bed3d5139fcbf9d9e8b85605c752736f26af6793 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -30,6 +30,13 @@ from tensorflow.python.platform import test class SDCALogisticClassifierTest(test.TestCase): + def _single_threaded_test_session(self): + # TODO(andreasst): figure out why SDCALinearRegressor needs a single + # threaded session to pass in tsan mode but SDCALogisticClassifier does not. + config = config_pb2.ConfigProto( + inter_op_parallelism_threads=1, intra_op_parallelism_threads=1) + return self.test_session(config=config) + def testRealValuedFeatures(self): """Tests SDCALogisticClassifier works with real valued features.""" @@ -41,7 +48,7 @@ class SDCALogisticClassifierTest(test.TestCase): 'weights': constant_op.constant([[1.0], [1.0]]) }, constant_op.constant([[0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): maintenance_cost = feature_column_lib.real_valued_column( 'maintenance_cost') sq_footage = feature_column_lib.real_valued_column('sq_footage') @@ -66,7 +73,7 @@ class SDCALogisticClassifierTest(test.TestCase): constant_op.constant([[500.0, 800.0], [200.0, 600.0]]) }, constant_op.constant([[0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): dense_feature = feature_column_lib.real_valued_column( 'dense_feature', dimension=2) classifier = sdca_estimator.SDCALogisticClassifier( @@ -86,7 +93,7 @@ class SDCALogisticClassifierTest(test.TestCase): 'weights': constant_op.constant([[1.0], [1.0], [1.0]]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): price_bucket = feature_column_lib.bucketized_column( feature_column_lib.real_valued_column('price'), boundaries=[500.0, 700.0]) @@ -120,7 +127,7 @@ class SDCALogisticClassifierTest(test.TestCase): constant_op.constant([[1.0], [1.0], [1.0]]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): price = feature_column_lib.real_valued_column('price') country = feature_column_lib.sparse_column_with_hash_bucket( 'country', hash_bucket_size=5) @@ -151,7 +158,7 @@ class SDCALogisticClassifierTest(test.TestCase): dense_shape=[3, 5]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): country = feature_column_lib.sparse_column_with_hash_bucket( 'country', hash_bucket_size=5) country_weighted_by_price = feature_column_lib.weighted_sparse_column( @@ -163,6 +170,38 @@ class SDCALogisticClassifierTest(test.TestCase): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testSparseFeaturesWithDuplicates(self): + """Tests SDCALogisticClassifier with duplicated sparse features.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2']), + 'age': + sparse_tensor.SparseTensor( + values=['20-29'] * 5 + ['31-40'] * 5, + indices=[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [1, 0], + [1, 0], [1, 0], [1, 0], [1, 0]], + dense_shape=[2, 1]), + 'gender': + sparse_tensor.SparseTensor( + values=['m'] * 5 + ['f'] * 5, + indices=[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [1, 0], + [1, 0], [1, 0], [1, 0], [1, 0]], + dense_shape=[2, 1]), + }, constant_op.constant([[1], [0]]) + + with self._single_threaded_test_session(): + age = feature_column_lib.sparse_column_with_hash_bucket( + 'age', hash_bucket_size=10) + gender = feature_column_lib.sparse_column_with_hash_bucket( + 'gender', hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + example_id_column='example_id', feature_columns=[age, gender]) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertLess(metrics['loss'], 0.060) + def testCrossedFeatures(self): """Tests SDCALogisticClassifier with crossed features.""" @@ -182,7 +221,7 @@ class SDCALogisticClassifierTest(test.TestCase): dense_shape=[3, 1]) }, constant_op.constant([[0], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): language = feature_column_lib.sparse_column_with_hash_bucket( 'language', hash_bucket_size=5) country = feature_column_lib.sparse_column_with_hash_bucket( @@ -215,7 +254,7 @@ class SDCALogisticClassifierTest(test.TestCase): constant_op.constant([[3.0], [1.0], [1.0]]) }, constant_op.constant([[1], [0], [1]]) - with self.test_session(): + with self._single_threaded_test_session(): price = feature_column_lib.real_valued_column('price') sq_footage_bucket = feature_column_lib.bucketized_column( feature_column_lib.real_valued_column('sq_footage'), diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 92d022f2a30ffeb77e81d3bd01365afcd14826b5..5d4572bf6c761e0de2c9e6d7e17193abf0ebb170 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib import layers from tensorflow.contrib.linear_optimizer.python.ops import sdca_ops from tensorflow.contrib.linear_optimizer.python.ops.sparse_feature_column import SparseFeatureColumn +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -71,12 +72,14 @@ class SDCAOptimizer(object): num_loss_partitions=1, num_table_shards=None, symmetric_l1_regularization=0.0, - symmetric_l2_regularization=1.0): + symmetric_l2_regularization=1.0, + adaptive=True): self._example_id_column = example_id_column self._num_loss_partitions = num_loss_partitions self._num_table_shards = num_table_shards self._symmetric_l1_regularization = symmetric_l1_regularization self._symmetric_l2_regularization = symmetric_l2_regularization + self._adaptive = adaptive def get_name(self): return 'SDCAOptimizer' @@ -101,6 +104,10 @@ class SDCAOptimizer(object): def symmetric_l2_regularization(self): return self._symmetric_l2_regularization + @property + def adaptive(self): + return self._adaptive + def get_train_step(self, columns_to_variables, weight_column_name, loss_type, features, targets, global_step): """Returns the training operation of an SdcaModel optimizer.""" @@ -175,28 +182,42 @@ class SDCAOptimizer(object): elif isinstance( column, ( + layers.feature_column._WeightedSparseColumn, # pylint: disable=protected-access layers.feature_column._CrossedColumn, # pylint: disable=protected-access layers.feature_column._SparseColumn)): # pylint: disable=protected-access - sparse_features.append( - SparseFeatureColumn( - array_ops.reshape( - array_ops.split( - value=transformed_tensor.indices, - num_or_size_splits=2, - axis=1)[0], [-1]), - array_ops.reshape(transformed_tensor.values, [-1]), None)) - sparse_feature_weights.append(columns_to_variables[column][0]) - elif isinstance(column, layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access - id_tensor = column.id_tensor(transformed_tensor) - weight_tensor = column.weight_tensor(transformed_tensor) + + if isinstance(column, layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access + id_tensor = column.id_tensor(transformed_tensor) + weight_tensor = array_ops.reshape( + column.weight_tensor(transformed_tensor).values, [-1]) + else: + id_tensor = transformed_tensor + weight_tensor = array_ops.ones( + [array_ops.shape(id_tensor.indices)[0]], dtypes.float32) + + example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1]) + + flat_ids = array_ops.reshape(id_tensor.values, [-1]) + projection_length = math_ops.reduce_max(flat_ids) + 1 + # project ids based on example ids so that we can dedup ids that + # occur multiple times for a single example. + projected_ids = projection_length * example_ids + flat_ids + + # Remove any redudant ids. + ids, idx = array_ops.unique(projected_ids) + # Keep only one example id per duplicated ids. + example_ids_filtered = math_ops.unsorted_segment_min( + example_ids, idx, + array_ops.shape(ids)[0]) + + # reproject ids back feature id space. + reproject_ids = (ids - projection_length * example_ids_filtered) + + weights = array_ops.reshape( + math_ops.unsorted_segment_sum(weight_tensor, idx, + array_ops.shape(ids)[0]), [-1]) sparse_feature_with_values.append( - SparseFeatureColumn( - array_ops.reshape( - array_ops.split( - value=id_tensor.indices, num_or_size_splits=2, axis=1) - [0], [-1]), - array_ops.reshape(id_tensor.values, [-1]), - array_ops.reshape(weight_tensor.values, [-1]))) + SparseFeatureColumn(example_ids_filtered, reproject_ids, weights)) sparse_feature_with_values_weights.append( columns_to_variables[column][0]) else: @@ -228,6 +249,7 @@ class SDCAOptimizer(object): options=dict( symmetric_l1_regularization=self._symmetric_l1_regularization, symmetric_l2_regularization=self._symmetric_l2_regularization, + adaptive=self._adaptive, num_loss_partitions=self._num_loss_partitions, num_table_shards=self._num_table_shards, loss_type=loss_type)) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 44c4a7e2ca8d019ca602c7f2b492cd1e70b17561..1534f97d7600151e78c7fa7e8509d9e871240421 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -89,6 +89,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], + deps = [":context"], ) cc_library( @@ -132,10 +133,12 @@ cc_library( ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", + ":util", + "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/profiling:profiler", "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/core:lib_platform", ], ) @@ -169,6 +172,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", "//tensorflow/contrib/lite/testing:util", @@ -232,6 +236,27 @@ cc_test( ], ) +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + ":context", + ], +) + +cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + deps = [ + ":context", + ":util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the serialization of a model with optional tensors. # Model tests @@ -248,18 +273,3 @@ cc_test( # ], # }), #) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "downloads", - "examples", - "gen", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 7f316292724ea0baaf034d4e914773ad97a957d4..65fba52d461461f4594e2222ef6df3849b741f99 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -27,10 +27,10 @@ LIBDIR := $(MAKEFILE_DIR)/gen/lib/ GENDIR := $(MAKEFILE_DIR)/gen/obj/ # Settings for the host compiler. -CXX := $(CC_PREFIX) gcc +CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG -CC := $(CC_PREFIX) gcc -CFLAGS := +CC := $(CC_PREFIX)gcc +CFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -57,10 +57,11 @@ LIBS := \ # If we're on Linux, also link in the dl library. ifeq ($(HOST_OS),LINUX) - LIBS += -ldl -lpthread + LIBS += -ldl endif include $(MAKEFILE_DIR)/ios_makefile.inc +include $(MAKEFILE_DIR)/rpi_makefile.inc # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. @@ -89,7 +90,8 @@ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ -$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) +$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ +$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 00e93d2c4f3ab27057b855fba6fccf2ec8d7a1c1..a676b705f143b393c7e5bfa9e40d23f9adb68dcc 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -1,235 +1,8 @@ # TensorFlow Lite -TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. -TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. +TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded +devices. It enables low-latency inference of on-device machine learning models +with a small binary size and fast performance supporting hardware acceleration. -![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with an Android Demo App - -This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo. - -There are 3 ways to get the demo app to your device - - Download the prebuilt binary or - - Use Android Studio to build the application or - - Download the source code for TensorFlow Lite and the demo and build it using bazel - -## Description -In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. - -## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary -[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) - -Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. - -## Building in Android Studio using TensorFlow Lite AAR from JCenter -The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. - - - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). - - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). - - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. - - Click through installing all the Gradle extensions it requests. - - Either - - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/` - - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) - - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory - - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from - `classifier = new ImageClassifierQuantizedMobileNet(getActivity());` - to - `classifier = new ImageClassifierFloatInception(getActivity());` - - Build and run the demo app - -## Building TensorFlow Lite and the demo app from source - -### Clone the TensorFlow repo -- git clone - [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow) - -### Install Bazel -If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) - -NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. - -### Install Android NDK and SDK -Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. - - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) - - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). - - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). - - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` - -``` -android_sdk_repository ( - name = "androidsdk", - api_level = 23, - build_tools_version = "23.0.2", - path = "/home/xxxx/android-sdk-linux/", -) - -android_ndk_repository( - name = "androidndk", - path = "/home/xxxx/android-ndk-r10e/", - api_level = 19, -) -``` - -Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). - -### Build the source code -Run bazel with the following command to build the demo. - -Build the demo app: - -``` -bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo -``` - -### Note - -Currently, we only support building the Android demo app within a Python 2 -environment (due to a Bazel bug). - -### More about the demo -The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. - -# iOS Demo App - -Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). - -This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: - -1. Run `third_party/tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. -1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. -1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. -1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. - -# TensorFlow Lite Quick Start - -## Step 1. Decide which GraphDef to use - Depending on the use case, the developer may choose to use one of the popular - open-sourced models such as InceptionV3 or MobileNets, re-train these models - with their own custom data set or even build their own custom model. - -### Using a pre-trained model - -[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes. - -[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers. - -[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now. - -These pre-trained models can be downloaded from [here](g3doc/models.md). - -### Retrain Inception-V3 or MobileNet for a custom data set -The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes. - -The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference. - - -### Train a custom model -A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. - -TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This -set will continue to expand in future releases of Tensorflow Lite. - - -## Step 2. Model format conversion - -The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below. - -A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph. - -Since we employ several formats, the following definitions may be useful: - - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions. - - - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted. - - - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint. - - - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. - - - TensorFlow lite model (.tflite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. - -### Freeze Graph -To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. - -The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). - -Graph freezing can be done using the command below (and modifying the arguments appropriately) - -``` -bazel build tensorflow/python/tools:freeze_graph - -bazel-bin/tensorflow/python/tools/freeze_graph\ - --input_graph=/tmp/mobilenet_v1_224.pb \ - --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ - --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ - --output_node_names=MobileNet/Predictions/Reshape_1 -``` - -The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with -graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). - -This frozen Graphdef is now ready to be converted to flatbuffer format (.tflite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. - -Here is a sample command line to convert the frozen Graphdef to '.tflite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. -(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). - -``` -bazel build tensorflow/contrib/lite/toco:toco - -bazel-bin/tensorflow/contrib/lite/toco/toco \ - --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ - --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ - --output_file=/tmp/mobilenet_v1_1.0_224.tflite --inference_type=FLOAT \ - --input_type=FLOAT --input_arrays=input \ - --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 -``` - -- The input_file argument should point to the frozen GraphDef file that holds the model architecture. -- The output_file argument should point to where the TensorFlow Lite model file should be generated. -- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. -- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. - -Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, - -```python -import tensorflow as tf - -img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) -out = tf.identity(val, name="out") -with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("converteds_model.tflite", "wb").write(tflite_model) - -``` -For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). - -You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). - -If you would like to see a visual description of your TensorFlow Lite model after conversion, you can use tensorflow/contrib/lite/tools/visualize.py by running -```sh -bazel run tensorflow/contrib/lite/tools:visualize -- model.tflite model_viz.html -``` -and then visualize the resulting HTML file in a browser. - -## Step 3. Use the TensorFlow Lite model for inference in a mobile app - -After completion of Step 2 the developer should have a .tflite model. - -### For Android -Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). - -The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). - -Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). - -### For iOS -Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. - -## Core ML support - -Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). +See the documentation: https://www.tensorflow.org/mobile/tflite/ +Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite) diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index 4b322e027d48f4bf9f90d5b873c449d1ec31cc49..a4772731ecda92431c412672610a39c188dabf27 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 87b17c338e7afc33d32dd9688cc0825ac319dd19..4f836d367747e06de682b5764206d33f6e2fb983 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/arena_planner.h" +#include namespace tflite { @@ -128,6 +129,11 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { + // Grow the size of `allocs_` if necessary. This allows allocating temporary + // tensors in op's `prepare` function. + TF_LITE_ENSURE(context_, graph_info_->num_tensors() >= allocs_.size()); + allocs_.resize(graph_info_->num_tensors()); + TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node)); TF_LITE_ENSURE_STATUS(Commit()); diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index 58bc164619c2c053b9492e9a0e5de2da30e199af..e9d0fbc5a9b5aec06e28da8757466b25f40da2f5 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -25,7 +25,7 @@ limitations under the License. namespace tflite { -class AllocationInfo; +struct AllocationInfo; // A memory planner that makes all the allocations using arenas. // @@ -33,7 +33,7 @@ class AllocationInfo; // each tensor needs to be allocated and deallocated, and preallocates all the // necessary memory (the PlanAllocations phase). It then assigns portions of // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may -// share some of the bufer if a tensor B is to be allocated after another tensor +// share some of the buffer if a tensor B is to be allocated after another tensor // A has been deallocated. // // If dynamic tensors are used the planning steps can be repeated during model diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 19829e4991651111e13fc1805f97daef8bc016a7..b8f6b7fd59af9834edb4aa7aefa524c25ede66d2 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -104,7 +104,7 @@ def tflite_jni_binary(name, """Builds a jni binary for TFLite.""" linkopts = linkopts + [ "-Wl,--version-script", # Export only jni functions & classes. - linkscript, + "$(location {})".format(linkscript), ] native.cc_binary( name=name, @@ -200,8 +200,7 @@ def gen_zipped_test_files(name, files): native.genrule( name = name + "_" + f + ".files", cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco - + " --zip_to_output " + f + - " $(@D) zipped"), + + " --zip_to_output " + f + " $(@D)"), outs = [out_file], tools = [ ":generate_examples", diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index 4a9023ff33de15dd384531d51e39de4ffeecdb8b..9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -19,11 +19,16 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../.." -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ +$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ +$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a lipo \ tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ diff --git a/tensorflow/contrib/lite/build_rpi_lib.sh b/tensorflow/contrib/lite/build_rpi_lib.sh new file mode 100755 index 0000000000000000000000000000000000000000..3824b16412ed26a6cab79df3242da6017c3322b0 --- /dev/null +++ b/tensorflow/contrib/lite/build_rpi_lib.sh @@ -0,0 +1,22 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." + +CC_PREFIX=arm-linux-gnueabihf- make -j 3 -f tensorflow/contrib/lite/Makefile TARGET=RPI TARGET_ARCH=armv7 diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 5dbeadd16582ec586adab100b8a46e10182bd5ee..4910c89eaebabb7bd9a4e003b75fa6de4d5af69d 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/context.h" + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -51,6 +53,8 @@ typedef struct { TfLitePadding padding; int stride_width; int stride_height; + int dilation_width_factor; + int dilation_height_factor; TfLiteFusedActivation activation; } TfLiteConvParams; @@ -174,6 +178,11 @@ typedef struct { int block_size; } TfLiteSpaceToDepthParams; +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + typedef enum { kTfLiteCombinerTypeSum = 0, kTfLiteCombinerTypeMean = 1, @@ -195,6 +204,10 @@ typedef struct { bool keep_dims; } TfLiteMeanParams; +typedef struct { + int num_splits; +} TfLiteSplitParams; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -210,6 +223,10 @@ typedef struct { int shrink_axis_mask; } TfLiteStridedSliceParams; +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..859bc7ab70dc363e08800ca5c40eb0da6ca426b0 --- /dev/null +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ + +// DO NOT EDIT MANUALLY: This file is automatically generated by +// `schema_builtin_ops_header_generator.py`. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// The enum for builtin operators. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real built-in ops. +typedef enum { + kTfLiteBuiltinAdd = 0, + kTfLiteBuiltinAveragePool2d = 1, + kTfLiteBuiltinConcatenation = 2, + kTfLiteBuiltinConv2d = 3, + kTfLiteBuiltinDepthwiseConv2d = 4, + kTfLiteBuiltinDequantize = 6, + kTfLiteBuiltinEmbeddingLookup = 7, + kTfLiteBuiltinFullyConnected = 9, + kTfLiteBuiltinHashtableLookup = 10, + kTfLiteBuiltinL2Normalization = 11, + kTfLiteBuiltinL2Pool2d = 12, + kTfLiteBuiltinLocalResponseNormalization = 13, + kTfLiteBuiltinLogistic = 14, + kTfLiteBuiltinLshProjection = 15, + kTfLiteBuiltinLstm = 16, + kTfLiteBuiltinMaxPool2d = 17, + kTfLiteBuiltinMul = 18, + kTfLiteBuiltinRelu = 19, + kTfLiteBuiltinReluN1To1 = 20, + kTfLiteBuiltinRelu6 = 21, + kTfLiteBuiltinReshape = 22, + kTfLiteBuiltinResizeBilinear = 23, + kTfLiteBuiltinRnn = 24, + kTfLiteBuiltinSoftmax = 25, + kTfLiteBuiltinSpaceToDepth = 26, + kTfLiteBuiltinSvdf = 27, + kTfLiteBuiltinTanh = 28, + kTfLiteBuiltinConcatEmbeddings = 29, + kTfLiteBuiltinSkipGram = 30, + kTfLiteBuiltinCall = 31, + kTfLiteBuiltinCustom = 32, + kTfLiteBuiltinEmbeddingLookupSparse = 33, + kTfLiteBuiltinPad = 34, + kTfLiteBuiltinUnidirectionalSequenceRnn = 35, + kTfLiteBuiltinGather = 36, + kTfLiteBuiltinBatchToSpaceNd = 37, + kTfLiteBuiltinSpaceToBatchNd = 38, + kTfLiteBuiltinTranspose = 39, + kTfLiteBuiltinMean = 40, + kTfLiteBuiltinSub = 41, + kTfLiteBuiltinDiv = 42, + kTfLiteBuiltinSqueeze = 43, + kTfLiteBuiltinUnidirectionalSequenceLstm = 44, + kTfLiteBuiltinStridedSlice = 45, + kTfLiteBuiltinBidirectionalSequenceRnn = 46, + kTfLiteBuiltinExp = 47, + kTfLiteBuiltinTopkV2 = 48, + kTfLiteBuiltinSplit = 49, + kTfLiteBuiltinLogSoftmax = 50, + kTfLiteBuiltinDelegate = 51, + kTfLiteBuiltinBidirectionalSequenceLstm = 52, + kTfLiteBuiltinCast = 53, + kTfLiteBuiltinPrelu = 54, + kTfLiteBuiltinMaximum = 55, + kTfLiteBuiltinArgMax = 56, + kTfLiteBuiltinMinimum = 57, + kTfLiteBuiltinLess = 58, +} TfLiteBuiltinOperator; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +} diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index c09e838c5c2e50e0f4a38eaf66e55246fd9a6f7f..5c6f5e72a47180cd98be46f60cfa8eaf28197806 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -17,9 +17,14 @@ limitations under the License. #include #include +int TfLiteIntArrayGetSizeInBytes(int size) { + static TfLiteIntArray dummy; + return sizeof(dummy) + sizeof(dummy.data[0]) * size; +} + TfLiteIntArray* TfLiteIntArrayCreate(int size) { TfLiteIntArray* ret = - (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + (TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size)); ret->size = size; return ret; } @@ -55,12 +60,16 @@ TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } -void TfLiteTensorFree(TfLiteTensor* t) { +void TfLiteTensorDataFree(TfLiteTensor* t) { if (t->allocation_type == kTfLiteDynamic && t->data.raw) { free(t->data.raw); } - if (t->dims) TfLiteIntArrayFree(t->dims); t->data.raw = NULL; +} + +void TfLiteTensorFree(TfLiteTensor* t) { + TfLiteTensorDataFree(t); + if (t->dims) TfLiteIntArrayFree(t->dims); t->dims = NULL; } diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index b0c4d3431f9a67bc87d51ada91ed73f1661023a2..0b38f43cd32fbdfa0296eec7ef81aab76ebe5461 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -29,6 +29,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#include #include #include @@ -40,6 +41,7 @@ typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; // Forward declare so GetNode can use this is in Context. typedef struct _TfLiteRegistration TfLiteRegistration; +typedef struct _TfLiteDelegate TfLiteDelegate; #define kOptionalTensor (-1) @@ -57,6 +59,10 @@ typedef struct { #endif } TfLiteIntArray; +// Given the size (number of elements) in a TfLiteIntArray, calculate its size +// in bytes. +int TfLiteIntArrayGetSizeInBytes(int size); + // Create a array of a given `size` (uninitialized entries). // This returns a pointer, that you must free using TfLiteIntArrayFree(). TfLiteIntArray* TfLiteIntArrayCreate(int size); @@ -131,6 +137,7 @@ typedef enum { kTfLiteUInt8 = 3, kTfLiteInt64 = 4, kTfLiteString = 5, + kTfLiteBool = 6, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -149,6 +156,7 @@ typedef union { char* raw; const char* raw_const; uint8_t* uint8; + bool* b; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -162,6 +170,11 @@ typedef enum { kTfLiteDynamic, } TfLiteAllocationType; +// The delegates should use zero or positive integers to represent handles. +// -1 is reserved from unallocated status. +typedef int TfLiteBufferHandle; +const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; + // An tensor in the interpreter system which is a wrapper around a buffer of // data including a dimensionality (or NULL if not currently defined). typedef struct { @@ -194,8 +207,27 @@ typedef struct { // Null-terminated name of this tensor. const char* name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; } TfLiteTensor; +// Free data memory of tensor `t`; +void TfLiteTensorDataFree(TfLiteTensor* t); + // Free memory of tensor `t`; void TfLiteTensorFree(TfLiteTensor* t); @@ -234,6 +266,11 @@ typedef struct { // WARNING: This is an experimental interface that is subject to change. const void* custom_initial_data; int custom_initial_data_size; + + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; } TfLiteNode; typedef struct TfLiteContext { @@ -258,7 +295,7 @@ typedef struct TfLiteContext { TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, TfLiteIntArray** execution_plan); - // An tensor of tensors in the interpreter context (of length `tensors_size`) + // An array of tensors in the interpreter context (of length `tensors_size`) TfLiteTensor* tensors; // opaque full context ptr (an opaque c++ data structure) @@ -283,14 +320,20 @@ typedef struct TfLiteContext { TfLiteNode** node, TfLiteRegistration** registration); - // Replace ops with delegate. + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( struct TfLiteContext*, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace); + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; // TODO(ahentz): we should create a more general mechanism for this sort of // library-global objects. void* gemm_context; + void* eigen_context; } TfLiteContext; typedef struct _TfLiteRegistration { @@ -337,19 +380,47 @@ typedef struct _TfLiteRegistration { } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. -typedef struct { +typedef struct _TfLiteDelegate { // Data that delegate needs to identify itself. This data is owned by the // delegate. The delegate is owned in the user code, so the delegate is // responsible for doing this when it is destroyed. void* data_; + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the // delegate a view of the current graph through TfLiteContext*. It typically // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() // to ask the TensorFlow lite runtime to create macro-nodes to represent // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(TfLiteContext* context, void* data); + TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + + // Copy the data from delegate buffer handle to raw memory. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, int size); + + // Copy the data from raw memory to delegate buffer handle. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, int size); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteDelegate* delegate, + TfLiteBufferHandle* handle); } TfLiteDelegate; +// WARNING: This is an experimental interface that is subject to change. +typedef struct { + TfLiteDelegate* delegate; + TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; +} TfLiteDelegateParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index a93ed201d647ddf2359a57254a959871c13fb94f..436c3e1d4cad5e6ee355d7e9cf8ee7da1a8385ce 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -30,12 +30,15 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once +# the archive has been propagated in mirror.bazel.build. +GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. @@ -91,6 +94,7 @@ download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index da193d2586e9123341b9a41be049ee2a4382017a..3c5f805f12f6a1fb7185c140604f692ac282a143 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -30,7 +30,7 @@ namespace tflite { // va_list args; // foo.Report("test %d", args); // where args is va_list // -// Sublclass ErrorReporter to provide another reporting destination. +// Subclass ErrorReporter to provide another reporting destination. // For example, if you have a GUI program, you might redirect to a buffer // that drives a GUI error log box. class ErrorReporter { diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..bc9574d646b7661de8ac9b745bd53cbba1eb9f31 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/AndroidManifest.xml @@ -0,0 +1,65 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..49280129971e38247c2216d9422bc5de9176e13d --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -0,0 +1,86 @@ +# Description: +# TensorFlow camera demo app for Android. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +# Build the demo native demo lib from the original directory to reduce code +# reuse. Note that the Java counterparts (ObjectTracker.java and +# ImageUtils.java) are still duplicated. +cc_library( + name = "tensorflow_native_libs", + srcs = [ + "//tensorflow/examples/android:libtensorflow_demo.so", + ], + tags = [ + "manual", + "notap", + ], +) + +android_binary( + name = "tflite_demo", + srcs = glob([ + "src/**/*.java", + ]), + # Package assets from assets dir as well as all model targets. + # Remove undesired models (and corresponding Activities in source) + # to reduce APK size. + assets = [ + "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", + "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", + "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + ], + assets_dir = "", + custom_package = "org.tensorflow.lite.demo", + inline_constants = 1, + manifest = "AndroidManifest.xml", + manifest_merger = "android", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = [ + "manual", + "notap", + ], + deps = [ + ":tensorflow_native_libs", + "//tensorflow/contrib/lite/java:tensorflowlite", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "bin/**", + "gen/**", + "gradleBuild/**", + "libs/**", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +filegroup( + name = "java_files", + srcs = glob(["src/**/*.java"]), +) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), +) + +exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd6c98ff878e9c41875cab74c12191cadb173 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/assets/box_priors.txt new file mode 100644 index 0000000000000000000000000000000000000000..7246b073fe7fd8b1d1340536457c8aeac24cd5a3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/box_priors.txt @@ -0,0 +1,5 @@ + 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.02631579 0.02631579 0.026315793 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.078947365 0.07894737 0.078947365 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.13157895 0.13157895 0.13157894 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.18421052 0.18421051 0.18421052 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.23684211 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.28947368 0.28947368 0.28947365 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.34210524 0.34210524 0.3421052 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.92105263 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.97368425 0.9736843 0.97368425 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.049999997 0.049999997 0.049999997 0.05 0.050000012 0.049999997 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.25 0.25 0.25 0.25 0.25000003 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000005 0.35000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.45 0.45000002 0.45000002 0.45000002 0.45000002 0.45000002 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.55 0.55 0.55 0.55 0.54999995 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.099999994 0.1 0.099999994 0.1 0.099999994 0.099999994 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.30000004 0.3 0.3 0.3 0.3 0.30000004 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.49999997 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.90000004 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.16666667 0.16666667 0.16666666 0.16666667 0.16666669 0.16666667 0.5 0.5 0.49999997 0.5 0.5 0.5 0.5 0.5 0.49999997 0.5 0.5 0.5 0.5 0.5 0.49999997 0.5 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.8333334 0.25 0.25 0.25 0.24999999 0.25 0.25 0.25 0.25 0.25 0.24999999 0.25 0.25 0.75 0.75 0.75 0.75 0.74999994 0.75 0.75 0.75 0.75 0.75 0.74999994 0.75 0.5 0.5 0.5 0.5 0.5 0.5 + 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.02631579 0.026315793 0.02631579 0.078947365 0.078947365 0.07894737 0.13157895 0.13157894 0.13157895 0.18421052 0.18421052 0.18421051 0.23684211 0.23684211 0.23684211 0.28947368 0.28947365 0.28947368 0.34210524 0.3421052 0.34210524 0.39473683 0.39473683 0.39473683 0.4473684 0.4473684 0.4473684 0.5 0.5 0.5 0.5526316 0.5526316 0.5526316 0.6052632 0.6052632 0.6052632 0.65789473 0.65789473 0.65789473 0.71052635 0.71052635 0.71052635 0.7631579 0.7631579 0.7631579 0.8157895 0.8157895 0.8157895 0.8684211 0.8684211 0.8684211 0.92105263 0.92105263 0.92105263 0.97368425 0.97368425 0.9736843 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.049999997 0.049999997 0.050000004 0.050000012 0.05 0.049999997 0.15 0.14999999 0.15 0.15 0.15 0.15 0.25 0.25 0.25 0.25 0.25 0.25 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.35000002 0.45000002 0.45 0.45000002 0.45000002 0.45 0.45000002 0.55 0.55 0.55 0.55 0.55 0.55 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.65000004 0.75 0.75 0.75 0.75 0.75 0.75 0.85 0.85 0.85 0.85 0.85 0.85 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.95000005 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.10000001 0.099999994 0.1 0.099999994 0.1 0.099999994 0.3 0.3 0.3 0.29999998 0.3 0.30000004 0.5 0.5 0.5 0.5 0.5 0.49999997 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.70000005 0.9 0.90000004 0.90000004 0.9 0.90000004 0.90000004 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.16666667 0.16666669 0.16666667 0.16666669 0.16666667 0.16666667 0.49999997 0.5 0.5 0.50000006 0.5 0.5 0.8333334 0.8333334 0.8333334 0.8333333 0.8333334 0.8333334 0.25 0.25 0.25 0.25 0.25 0.25 0.75 0.75 0.75 0.75 0.75 0.75 0.25 0.25 0.25 0.25 0.25 0.25 0.75 0.75 0.75 0.75 0.75 0.75 0.5 0.5 0.5 0.5 0.5 0.5 + 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.1 0.14142136 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.28284273 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142138 0.2828427 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.099999994 0.14142135 0.28284273 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.10000001 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142138 0.2828427 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142135 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.2828427 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142132 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.100000024 0.14142138 0.28284276 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.2474874 0.4949748 0.20207259 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748738 0.4949748 0.20207258 0.6062481 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.4949748 0.2020726 0.60624814 0.41833 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35000002 0.24748741 0.49497482 0.2020726 0.60624814 0.41832998 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35 0.24748737 0.4949748 0.20207256 0.6062481 0.41833 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497476 0.20207262 0.606248 0.41833004 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000002 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.35000008 0.24748743 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.34999996 0.24748737 0.49497485 0.20207262 0.60624814 0.41832995 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.50000006 0.3535534 0.7071068 0.28867513 0.8660687 0.57008773 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5000001 0.3535534 0.7071068 0.28867513 0.8660687 0.5700878 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.2886751 0.8660687 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5 0.3535534 0.7071068 0.28867507 0.8660688 0.5700877 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.5000001 0.3535534 0.70710677 0.2886752 0.8660687 0.5700878 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.65000004 0.45961943 0.91923887 0.37527767 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.4596194 0.9192388 0.37527764 1.1258893 0.7211102 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.6500001 0.45961946 0.9192388 0.3752777 1.1258893 0.72111017 0.8000001 0.5656855 1.131371 0.4618802 1.3857099 0.8717798 0.8000001 0.5656855 1.131371 0.4618802 1.3857099 0.8717798 0.80000013 0.5656855 1.131371 0.4618802 1.3857098 0.87177986 0.80000013 0.5656855 1.131371 0.4618802 1.3857098 0.87177986 0.95000005 0.6717515 1.343503 0.5484828 1.6455305 0.97467947 + 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.1 0.28284273 0.14142136 0.099999994 0.28284273 0.14142138 0.099999994 0.2828427 0.14142138 0.099999994 0.28284273 0.14142135 0.099999994 0.28284273 0.14142135 0.10000001 0.2828427 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142138 0.100000024 0.2828427 0.14142138 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142135 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.2828427 0.14142132 0.100000024 0.28284276 0.14142132 0.100000024 0.28284276 0.14142138 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.34999996 0.4949747 0.24748735 0.60621774 0.20206249 0.41833 0.34999996 0.49497467 0.24748737 0.60621774 0.20206249 0.41833 0.34999996 0.49497473 0.24748737 0.60621774 0.20206249 0.41833 0.34999993 0.49497473 0.24748737 0.60621774 0.20206249 0.41832998 0.34999996 0.49497467 0.24748737 0.60621774 0.20206246 0.41833 0.35 0.49497473 0.24748734 0.60621774 0.20206249 0.41833004 0.35 0.49497473 0.2474873 0.60621774 0.20206249 0.41833004 0.3499999 0.49497473 0.2474873 0.6062178 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062177 0.20206249 0.41832995 0.3499999 0.49497467 0.2474873 0.6062178 0.20206255 0.41832995 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.49999997 0.7071067 0.35355335 0.8660254 0.2886607 0.57008773 0.5 0.7071067 0.35355335 0.8660253 0.2886607 0.5700878 0.5 0.7071067 0.35355332 0.86602545 0.28866073 0.5700877 0.5 0.70710665 0.3535533 0.86602545 0.28866076 0.5700877 0.49999994 0.7071067 0.3535534 0.8660253 0.28866065 0.5700878 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.6499999 0.9192387 0.45961934 1.1258329 0.3752589 0.7211102 0.64999986 0.9192387 0.4596193 1.125833 0.37525892 0.7211102 0.64999986 0.91923875 0.45961928 1.1258328 0.37525892 0.72111017 0.79999995 1.1313708 0.5656854 1.3856406 0.46185714 0.8717798 0.79999995 1.1313708 0.56568533 1.3856406 0.46185708 0.87177986 0.79999995 1.1313708 0.5656854 1.3856406 0.46185714 0.8717798 0.79999995 1.1313708 0.56568533 1.3856406 0.46185708 0.87177986 0.9499999 1.3435028 0.6717514 1.6454482 0.54845536 0.97467947 + diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt new file mode 100644 index 0000000000000000000000000000000000000000..5a70ff82aa7b0fa7315ca591820e4cf7d2f5ad18 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt @@ -0,0 +1,91 @@ +??? +person +bicycle +car +motorcycle +airplane +bus +train +truck +boat +traffic light +fire hydrant +??? +stop sign +parking meter +bench +bird +cat +dog +horse +sheep +cow +elephant +bear +zebra +giraffe +??? +backpack +umbrella +??? +??? +handbag +tie +suitcase +frisbee +skis +snowboard +sports ball +kite +baseball bat +baseball glove +skateboard +surfboard +tennis racket +bottle +??? +wine glass +cup +fork +knife +spoon +bowl +banana +apple +sandwich +orange +broccoli +carrot +hot dog +pizza +donut +cake +chair +couch +potted plant +bed +??? +dining table +??? +??? +toilet +??? +tv +laptop +mouse +remote +keyboard +cell phone +microwave +oven +toaster +sink +refrigerator +??? +book +clock +vase +scissors +teddy bear +hair drier +toothbrush diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..ba416458b011a7f4b96739eb6fcb6275a6ab3bec --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt @@ -0,0 +1,12 @@ +_silence_ +_unknown_ +yes +no +up +down +left +right +on +off +stop +go \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..0d4de358156a5d139e35cc542b8d36ab24e763b9 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -0,0 +1,52 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml new file mode 100644 index 0000000000000000000000000000000000000000..891d8cc1d4f3e59d0371030fd763c5ad468e7887 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml @@ -0,0 +1,30 @@ + + + + + diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..32bd1aabcabb85ded957230533c00e735183a323 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..b3113cd15c3255405ee34c622a1e83674e6e5487 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000000000000000000000000000000000..135862883e26eddce2b19db021adf62e10357ad0 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..8efbbf8b3c44418551699db9388cd77a88362112 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..51f87ee6507cebec6bff32b1a03b36ffc711689d Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..ba143ea7a80f03b0e850775ad672ccb2d6195e4c Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..6361d792dacd8ce09a14258878b5ce6db5e0debb Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..394eb7e534905e36fd24c3defac92c09b403ee39 Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..2e27bec9785d4d51fe597bced7f04508994aa10c Binary files /dev/null and b/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/res/drawable/border.xml new file mode 100644 index 0000000000000000000000000000000000000000..dd1d64d1d61f359422c79533f726991c78e47d99 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/drawable/border.xml @@ -0,0 +1,19 @@ + + + + + diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..1a22d4b33ebbd755104272863c5cc6c93793b86b --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml new file mode 100644 index 0000000000000000000000000000000000000000..2fe1338da57122c7e26c64c653076b6746a25497 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml @@ -0,0 +1,55 @@ + + + + + + + +